# Deterministic slice flows

Vectorized version of the functions in `density_evaluation`.

Consider a target density $\pi(x_{1:M},u_{1:M})=p(x_{1:M})1_{[0,1]}(u_{1:M})$,
where we have access to the full conditionals $p_m$,
and let $F_m,Q_m$ be the cdf and quantile function of $p_m$.
We approximate $\pi$ with a variational ergodic flow
$$
    q_N=\frac{1}{N}\sum_{n=0}^{N-1} T^nq_0,
$$
where $q_0$ is a reference distribution and
$T=T_m\circ\cdots\circ T_1$ mimics Gibbs sampling, i.e.,
each map $T_m$ updates only $(x_m,u_m)\mapsto(x_m',u_m')$.
Specifically,
$$
\begin{pmatrix}x_m'\\u_m'\end{pmatrix}=
\begin{pmatrix}
    Q_m(\rho(x_m,u_m)+\xi\mod 1)\\ 
    \frac{1}{p_m(x_m')}((\rho(x_m,u_m)+\xi\mod 1)-F_m(x_m'-1))
\end{pmatrix},
$$
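where $\rho(x,u)=F_m(x-1)+u\,p_m(x)$ maps a state–auxiliary pair to a proportion in $[0,1]$, and $F_m(0)=0$ by convention (in the code below states are 0-indexed, so the convention reads $F_m(-1)=0$).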

In the notes,
I showed that the variational density can be evaluated in closed form
since the Jacobians of the continuous restriction correspond to density ratios. Specifically,
$$
    q_N(x_{1:M},u_{1:M})
    =\frac{1}{N}\sum_{n=0}^{N-1}
    q_0(T^{-n}(x_{1:M},u_{1:M}))
    \prod_{j=1}^n\prod_{m=1}^M \frac{p_m(T^{-j+1}(x_{1:M},u_{1:M}))}{p_m(T^{-j}(x_{1:M},u_{1:M}))}.
$$

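Because each $T_m$ changes only $(x_m,u_m)$, each conditional density ratio equals the corresponding ratio of joint densities (the normalizing constant of $p_m$ cancels). Within a sweep these ratios telescope to $p(T^{-j+1}(x_{1:M},u_{1:M}))/p(T^{-j}(x_{1:M},u_{1:M}))$, with $p$ evaluated on the $x$-coordinates; this log-ratio is what `lqN` accumulates.

First we define all the necessary functions.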

In [2]:
import numpy as np
import scipy.stats as stats
import pandas as pd
import matplotlib.pyplot as plt

# plotting defaults for every figure in this notebook
plt.rcParams.update({
    'figure.max_open_warning': 0,
    'figure.figsize': (15, 7.5),
    'font.size': 24,
})

In [258]:
########################################
########################################
# variational approximation functions
########################################
########################################
def lqN(x,u,N,lq0,lp,xi=np.pi/16):
    # log density of the ergodic-flow mixture q_N at (x,u)
    #
    # log q_N(z) = log( (1/N) sum_{n=0}^{N-1} q0(T^{-n} z) J_n(z) ), where
    # log J_n accumulates np.sum(lp(before)) - np.sum(lp(after)) over n backward steps
    #
    # inputs:
    #    x   : array, states x_{1:M}
    #    u   : array, auxiliary values u_{1:M}
    #    N   : int >= 1, number of mixture components
    #    lq0 : function, lq0(x,u) -> log reference density
    #    lp  : function, lp(x) -> log target (summed with np.sum); also passed to flow
    #    xi  : scalar, uniform shift used by the flow
    #
    # outputs:
    #    scalar, log q_N(x,u)
    #
    # raises:
    #    ValueError if N < 1
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N}")
    if N==1: return lq0(x,u)
    # flow() may assign into its arguments; work on copies so the caller's
    # arrays are left untouched
    x,u=np.copy(x),np.copy(u)
    w=np.zeros(N)
    w[0]=lq0(x,u)
    LJ=0
    lp_cur=np.sum(lp(x))  # log target at the current point, reused across steps
    for n in range(1,N):
        x,u=flow(x,u,1,lp,xi,direction='bwd')
        lp_new=np.sum(lp(x))
        LJ=LJ+lp_cur-lp_new
        w[n]=lq0(x,u)+LJ
        lp_cur=lp_new
    # end for
    return LogSumExp(w)-np.log(N)

def randqN(size,N,randq0,lp=None,xi=None):
    # draw samples from the ergodic-flow mixture q_N
    #
    # inputs:
    #    size   : int, number of samples
    #    N      : int, number of mixture components (flow steps 0..N-1)
    #    randq0 : function, randq0(size) -> (x,u) reference samples; row i is sample i
    #             (assumes the first axis indexes samples -- TODO confirm for all randq0)
    #    lp     : function passed to flow; if None, the module-level `lp` is used
    #    xi     : scalar shift passed to flow; if None, the module-level `xi` is used
    #             when defined, otherwise np.pi/16 (flow's own default)
    #
    # outputs:
    #    x, u : samples; row i is T^{K_i}(x_i,u_i) with K_i uniform on {0,...,N-1}
    #
    # raises:
    #    NameError if lp is None and no module-level `lp` exists (and N > 1)
    if N==1: return randq0(size)
    if lp is None:
        # previous behaviour: look up a global `lp`
        if 'lp' not in globals():
            raise NameError("randqN: no `lp` given and no module-level `lp` is defined")
        lp=globals()['lp']
    if xi is None:
        xi=globals().get('xi',np.pi/16)
    K=np.random.randint(low=0,high=N,size=size)
    x,u=randq0(size)
    for i in range(size):
        tx,tu=flow(np.atleast_1d(x[i,...]),np.atleast_1d(u[i,...]),steps=K[i],lp=lp,xi=xi,direction='fwd')
        x[i,...]=tx
        u[i,...]=tu
    return x,u
    


########################################
########################################
# flow functions
########################################
########################################
def flow(x,u,steps,lp,xi=np.pi/16,direction='fwd'):
    # apply `steps` sweeps of T = T_M o ... o T_1 (direction='fwd') or of its
    # inverse (direction='bwd', coordinates visited from the last to the first)
    #
    # inputs:
    #    x         : array whose first axis indexes the M variables, states x_{1:M}
    #    u         : array with the same layout as x, auxiliary values u_{1:M}
    #    steps     : int, number of sweeps
    #    lp        : function, lp(x,axis=m) -> log-scale (possibly unnormalized)
    #                conditional weights of x_m given x_{-m}; exponentiated and
    #                normalized here before being passed to Tm
    #    xi        : scalar, uniform shift
    #    direction : string, 'fwd' (forward map) or 'bwd' (backward map)
    #
    # outputs:
    #    x', u' : updated copies; the input arrays are NOT modified
    #
    # copy first: the per-coordinate updates below assign into the arrays and
    # would otherwise silently overwrite the caller's data
    x=np.copy(x)
    u=np.copy(u)
    M=x.shape[0]
    if steps==0: return x,u
    order=range(M) if direction=='fwd' else range(M-1,-1,-1)
    for _ in range(steps):
        for m_idx in order:
            tmp_prbs=np.atleast_1d(np.exp(lp(x,axis=m_idx)))
            tx,tu=Tm(x[m_idx],u[m_idx],tmp_prbs/np.sum(tmp_prbs),xi,direction=direction)
            x[m_idx]=tx
            u[m_idx]=tu
        # end for
    # end for
    return x,u
        
    
def Tm(x,u,prbs,xi=np.pi/16,direction='fwd'):
    # compute Tm(x,u)
    #
    # inputs:
    #    x         : (d,) array, states of xm
    #    u         : (d,) array, values of um
    #    prbs      : (Km,) array, probabilities of Xm|X-m
    #    xi        : scalar, uniform shift
    #    direction : string, one of 'fwd' (forward map) or 'bwd' (backward map)
    #
    # outputs:
    #   xp : (d,) array, updated states xm'
    #   up : (d,) array, updated values um' = (p - F(xm'-1)) / prbs[xm']
    
    if direction=='bwd': xi=-xi  # the inverse map shifts the proportion back by xi
    p=getp(x,u,prbs,xi)
    xp=quantile(p,prbs)
    # F(xm'-1) with F(-1)=0: guard xm'=0 explicitly, since a plain lookup F[-1]
    # wraps around to the total mass 1 and would make um' negative
    Fprev=np.where(xp>0,cdf(xp-1,prbs),0.0)
    up=(p-Fprev)/prbs[xp]
    return xp,up


def getp(x,u,prbs,xi=np.pi/16):
    # get proportion from current pair (xm,um)
    # equivalent to rho+xi mod 1 in paper, with rho(x,u)=F(x-1)+u*prbs[x], F(-1)=0
    #
    # inputs:
    #    x         : (d,) array or scalar, states of xm
    #    u         : (d,) array or scalar, values of um
    #    prbs      : (Km,) array, probabilities of Xm|X-m
    #    xi        : scalar, uniform shift
    #
    # outputs:
    #   p' : (d,) array (or scalar for scalar inputs), shifted proportions p'
    
    # zero-prefixed cdf: Fprev[k] = F(k-1), so Fprev[0] = 0 handles x=0 without
    # masking (and works for scalar x, which flow() passes one coordinate at a time)
    Fprev=np.concatenate(([0.0],np.cumsum(prbs)))
    p=Fprev[x]+u*prbs[x]
    return (p+xi)%1
    
    
########################################
########################################
# inference
########################################
########################################
def elbo(B,lp,N,M,lq0,randqN,randq0,xi=np.pi/16):
    # Monte Carlo ELBO estimate: average over B draws (x,u) ~ randqN of lp(x) - lqN(x,u)
    #
    # inputs:
    #    B      : int, number of Monte Carlo samples
    #    lp     : function, log target evaluated at each sampled x
    #    N      : int, number of mixture components
    #    M      : unused in this function
    #    lq0    : function, log reference density passed to lqN
    #    randqN : function, randqN(B,N,randq0) -> (x,u) samples
    #    randq0 : function, reference sampler passed to randqN
    #    xi     : scalar, uniform shift passed to lqN
    #
    # outputs:
    #    scalar, mean of the per-sample ELBO terms
    #
    # NOTE(review): randqN is called without xi, so the sampler does not see this xi
    samples_x,samples_u=randqN(B,N,randq0)
    terms=np.zeros(B)
    for b in range(B):
        xb=np.atleast_1d(samples_x[b])
        ub=np.atleast_1d(samples_u[b])
        terms[b]=lp(xb)-lqN(xb,ub,N,lq0,lp,xi)
    return terms.mean()
        

########################################
########################################
# auxiliary functions
########################################
########################################
def LogSumExp(w):
    # LogSumExp trick: numerically stable log(sum(exp(w)))
    #
    # inputs:
    #    w : (d,) array, exponents (entries may be -inf, i.e. log 0)
    #
    # outputs:
    #    scalar, log(sum(exp(w)))
    w=np.asarray(w,dtype=float)
    wmax=np.amax(w)
    # a non-finite max would give inf-inf = nan below; the answer is then wmax
    # itself (-inf: every term is zero; +inf: the sum is infinite)
    if not np.isfinite(wmax):
        return wmax
    return wmax + np.log(np.sum(np.exp(w-wmax)))

def cdf(x,prbs): 
    # cdf of x given prbs (vectorized): F(x) = P(X <= x)
    #
    # inputs:
    #    x    : (d,) array or scalar of ints, states of xm (values outside
    #           0..Km-1 are allowed: F(x)=0 for x<0 and F(x)=1 for x>=Km-1)
    #    prbs : (Km,) array, probabilities of Xm|X-m
    #
    # outputs:
    #   F(x) : (d,) array, cdf of X at each xi (F(x)_i=F(x_i))
    
    # zero-prefixed cumulative sums: Fz[k+1] = F(k) and Fz[0] = F(-1) = 0;
    # clamping prevents x=-1 from wrapping around to the last entry (mass 1)
    Fz=np.concatenate(([0.0],np.cumsum(prbs)))
    return Fz[np.clip(x,-1,prbs.shape[0]-1)+1]

def quantile(u,prbs): 
    # quantile function of u given prbs (vectorized): Q(u) = min{k : F(k) >= u}
    #
    # inputs:
    #    u    : (d,) array or scalar, values in [0,1]
    #    prbs : (Km,) array, probabilities of Xm|X-m
    #
    # outputs:
    #   Q(u) : (d,) int array, quantile of X at each ui (Q(u)_i=Q(u_i))
    F=np.cumsum(prbs)
    # side='left' returns the first index k with F[k] >= u
    idx=np.searchsorted(F,u,side='left')
    # round-off can leave F[-1] slightly below 1, so u close to 1 could map past
    # the last state; clip to the valid range 0..Km-1
    return np.clip(idx,0,prbs.shape[0]-1).astype(int)


def gen_lp(prbs):
    # generate an lp function given an M-dimensional array prbs
    #
    # inputs:
    #    prbs : (K1,...,KM) array, (log-)probabilities
    #           NOTE(review): flow() exponentiates lp(x,axis) and lqN() treats
    #           lp(x) as a log-density, so this presumably should be an array of
    #           log-probabilities -- confirm with callers
    #
    # outputs:
    #   my_lp : function, obtains joint and conditional values
    #           my_lp(x)      -> prbs entry at states x
    #           my_lp(x,axis) -> fresh float copy of the slice
    #                            prbs[x_1,...,x_{m-1},:,x_{m+1},...,x_M] with m=axis
    
    def my_lp(x,axis=None):
        if axis is None: return prbs[tuple(x)]  # evaluate lp(x)
        # replace the axis-th index by a full slice to read the whole fibre at once
        idx=list(x)
        idx[axis]=slice(None)
        return np.array(prbs[tuple(idx)],dtype=float)
    return my_lp

In [251]:
def Tm(x,u,prbs,xi=np.pi/16,direction='fwd'):
    # compute Tm(x,u): move (xm,um) to (xm',um') by shifting the proportion by xi
    #
    # inputs:
    #    x         : (d,) array, states of xm
    #    u         : (d,) array, values of um
    #    prbs      : (Km,) array, probabilities of Xm|X-m
    #    xi        : scalar, uniform shift
    #    direction : string, 'fwd' (forward map) or 'bwd' (backward map)
    #
    # outputs:
    #   xp : (d,) array, updated states xm'
    #   up : (d,) array, updated values um' = (p - F(xm'-1)) / prbs[xm']
    if direction=='bwd': xi=-xi  # the inverse map shifts the proportion back by xi
    p=getp(x,u,prbs,xi)
    xp=quantile(p,prbs)
    # F(xm'-1) with F(-1)=0: guard xm'=0 explicitly, since a plain lookup F[-1]
    # wraps around to the total mass 1 and would make um' negative
    Fprev=np.where(xp>0,cdf(xp-1,prbs),0.0)
    up=(p-Fprev)/prbs[xp]
    return xp,up

In [252]:
np.random.seed(2022)
M = 4
# random pmf over M states, normalized to sum to one
prbs = np.random.rand(M)
prbs /= prbs.sum()

# two single-state test points and their auxiliary values
x1, x2 = np.array([0], dtype=int), np.array([2], dtype=int)
u1, u2 = np.array([0.999]), np.array([0.25])

In [256]:
def Tm2(x,u,prbs,xi=np.pi/16,direction='fwd'):
    # variant of Tm built on quantile2/cdf2
    # NOTE(review): quantile2 and cdf2 are not defined in this view --
    # presumably in another cell; confirm before relying on this function
    #
    # outputs:
    #   (xm', um') with um' = (p - cdf2(xm'-1)) / prbs[xm']
    shift = -xi if direction == 'bwd' else xi
    prop = getp(x, u, prbs, shift)
    x_new = quantile2(prop, prbs)
    u_new = (prop - cdf2(x_new - 1, prbs)) / prbs[x_new]
    return x_new, u_new

In [257]:
# apply one forward Tm2 update to two (state, auxiliary) pairs at once
x, u = np.array([0, 2], dtype=int), np.array([0.999, 0.25])
Tm2(x, u, prbs)

(array([1, 3]), array([0.26428438, 0.93778281]))