# Deterministic slice flows with integer arithmetic

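Consider a target density $\pi(x,u)=p(x)1_{[0,1]}(u)$, where $p$ is a probability mass function on $\{1,\dots,K\}$ and $F$ is its cdf.
The idea is to update a pair $(x,u)$ to $(x',u')$ via
$$
\begin{pmatrix}x'\\u'\end{pmatrix}=
\begin{pmatrix}
    F^{-1}\big((\rho(u,x)+\xi)\bmod 1\big)\\
    \frac{1}{p(x')}\Big(\big((\rho(u,x)+\xi)\bmod 1\big)-F(x'-1)\Big)
\end{pmatrix},
$$
where $\rho(u,x)=F(x-1)+u\,p(x)$ converts the pair to a proportion, $F(0)=0$ by convention, and $F^{-1}(v)=\min\{x: F(x)>v\}$.
Because $\rho(u',x')=(\rho(u,x)+\xi)\bmod 1$, applying the same map with $-\xi$ in place of $\xi$ returns $(x,u)$, which is how the backward map is implemented below.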

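A minimal sketch of the map in exact rational arithmetic makes the invertibility concrete. It assumes a small rational pmf (not the notebook's `prbs`), 1-based $x$ as in the formulas, and the rational shift $\xi=3/16$ as a stand-in for $\pi/16$, which `fractions.Fraction` cannot represent; the helpers `F`, `quantile`, `rho` and `step` are illustrative names, not functions from the notebook.

```python
from fractions import Fraction as Fr

# Illustrative rational pmf on x = 1..K (an assumption, not the notebook's prbs)
p = [Fr(1, 20), Fr(2, 5), Fr(1, 4), Fr(3, 10)]
K = len(p)

def F(x):
    """cdf with F(0) = 0 by convention; x is 1-based."""
    return sum(p[:x], Fr(0))

def quantile(v):
    """F^{-1}(v): smallest x in 1..K with F(x) > v, for v in [0, 1)."""
    for x in range(1, K + 1):
        if F(x) > v:
            return x
    raise ValueError("v must lie in [0, 1)")

def rho(u, x):
    """Convert (x, u) to the proportion F(x-1) + u p(x)."""
    return F(x - 1) + u * p[x - 1]

def step(x, u, xi):
    """One step of the flow: rotate rho by xi modulo 1, then read off (x', u')."""
    v = (rho(u, x) + xi) % 1
    xp = quantile(v)
    up = (v - F(xp - 1)) / p[xp - 1]
    return xp, up

xi = Fr(3, 16)          # rational stand-in for the shift (assumption)
x, u = 1, Fr(111, 200)  # u = 0.555
state = (x, u)
for _ in range(10):
    state = step(*state, xi)    # forward map
for _ in range(10):
    state = step(*state, -xi)   # backward map: same step with -xi
print(state == (x, u))          # True
```

With no rounding anywhere, ten forward steps followed by ten steps with $-\xi$ return exactly to the starting pair.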
We consider a univariate and a bivariate example,
both taken from Trevor's notebook.
First we define the necessary functions.

In [1]:
```python
import numpy as np
import scipy.stats as stats
import pandas as pd
import matplotlib.pyplot as plt

plt.rcParams.update({'figure.max_open_warning': 0})
plt.rcParams["figure.figsize"]=15,7.5
plt.rcParams.update({'font.size': 24})
```

In [2]:
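```python
# main functions
def update(x,u,prbs,xi=np.pi/16,direction='fwd'):
    """Floating-point map; x is a 0-based index into prbs, u in [0,1)."""
    if direction=='bwd': xi=-xi              # backward map = same map with -xi
    p=getp(x,u,prbs,xi)                      # (rho(u,x)+xi) mod 1
    xp=quantile(p,prbs)                      # x' = F^{-1}(p)
    up=(p-cdf(xp-1,prbs))/prbs[xp]           # u' = (p - F(x'-1)) / p(x')
    return xp,up


def getp(x,u,prbs,xi=np.pi/16):
    p=u*prbs[x]                              # u p(x)
    if x>0:  p+=np.sum(prbs[:x])             # + F(x-1)
    #return lcg_update(p*2e32,m=2e32,a=1664525.,c=1013904223.)/2e32
    return (p+xi)%1

def int_update(x,u,prbs,xi=np.pi/16,bigint=1e1,direction='fwd'):
    """Same map with u and prbs scaled by bigint and truncated to integers."""
    if direction=='bwd': xi=-xi
    # integerize (truncation, so sum(Mprbs) may fall slightly below bigint)
    Mu=int(u*bigint)
    Mprbs=(prbs*bigint).astype(int)

    Mp=int_getp(x,Mu,Mprbs,xi,modulus=bigint)
    xp=quantile(Mp,Mprbs)
    Mup=bigint*(Mp-cdf(xp-1,Mprbs))/Mprbs[xp]   # u' in scaled units
    return xp,Mup/bigint

def int_getp(x,u,prbs,xi=np.pi/16,modulus=1):
    # here u and prbs are the integerized Mu and Mprbs; modulus is the scale bigint
    p=int(u*prbs[x]/modulus)                 # scaled u p(x), truncated
    if x>0:  p+=np.sum(prbs[:x])             # + scaled F(x-1)
    #return lcg_update(p*2e32,m=2e32,a=1664525.,c=1013904223.)/2e32
    return (p+int(modulus*xi))%modulus       # rotate by the scaled xi

# auxiliary functions
def cdf(x,prbs): return np.sum(prbs[:(x+1)])               # F(x) for 0-based x; cdf(-1,prbs)=0
def quantile(u,prbs): return np.argmax(np.cumsum(prbs)>u)  # smallest x with F(x)>u; 0 if none
def lcg_update(u,m,a=1.,c=0.): return (a*u+c)%m           # linear congruential step (only in commented-out code)
```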

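`int_update` above still truncates $u$ and returns a floating-point $u'$, so running it backward can only approximately recover the start. One way to make the flow exactly reversible, sketched here as an assumption and not as the notebook's method, is to keep the whole state integer: with positive integer weights `W` summing to `N`, store $x$ with an integer offset $U\in\{0,\dots,W[x]-1\}$ (so $u\approx U/W[x]$), rotate the position $C[x]+U$ (where $C[x]$ is the total weight of the categories before $x$) by an integer shift `S` modulo `N`, and read off $(x',U')$. The function `int_flow_step` and the names `W`, `C`, `S` are hypothetical.

```python
import numpy as np

def int_flow_step(x, U, W, S):
    """One step of a fully integer slice flow (hypothetical sketch).

    x : 0-based category index
    U : integer offset in {0, ..., W[x]-1}, the integer analogue of u * W[x]
    W : positive integer weights (scaled probabilities), total N = sum(W)
    S : integer shift (scaled xi); pass -S for the backward map
    """
    C = np.concatenate(([0], np.cumsum(W)))              # C[k] = scaled F(k-1); C[-1] = N
    N = int(C[-1])
    pos = (int(C[x]) + U + S) % N                        # rotated integer position in [0, N)
    xp = int(np.searchsorted(C, pos, side='right')) - 1  # C[xp] <= pos < C[xp+1]
    return xp, pos - int(C[xp])

# Integer weights from the notebook's probabilities, rounded rather than truncated
prbs = np.array([0.047480609696438775, 0.40224879672133207,
                 0.23127020434512163, 0.31900038923710755])
prbs = prbs / prbs.sum()
bigint = 10**8
W = np.maximum(np.rint(prbs * bigint).astype(np.int64), 1)
S = int(np.pi / 16 * bigint)

x0, U0 = 0, int(0.555 * W[0])
x, U = x0, U0
for _ in range(10):
    x, U = int_flow_step(x, U, W, S)    # forward
for _ in range(10):
    x, U = int_flow_step(x, U, W, -S)   # backward
print((x, U) == (x0, U0))               # True: the integer map is a bijection
```

Each step is a rotation of the finite set $\{0,\dots,N-1\}$, hence a permutation, so forward and backward runs agree exactly. Taking `N` as the actual sum of the weights also leaves no gap between the scaled masses and the modulus. The price is that $u$ is resolved only in steps of $1/W[x]$.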
## Univariate example

In [3]:
```python
# generate the distribution
np.random.seed(2022)
K=4
prbs=np.random.rand(K)
prbs=np.array([0.047480609696438775, 0.40224879672133207, 0.23127020434512163, 0.31900038923710755]) # Trevor's example (replaces the random draw)
prbs=prbs/np.sum(prbs)
prbs
```

```
array([0.04748061, 0.4022488 , 0.2312702 , 0.31900039])
```

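`int_update` integerizes the probabilities with `(prbs*bigint).astype(int)`, which truncates every entry, so the scaled masses can sum to slightly less than `bigint` while `int_getp` returns positions anywhere in `[0, bigint)`. The check below (a sketch; `deficit` is a name introduced here) measures that gap for the probabilities above. If a rotated position `Mp` lands in it, no entry of the cumulative sum exceeds `Mp`, `np.argmax` over an all-`False` array returns `0`, and the state is sent to the first category with `u'` equal to `Mp/Mprbs[0]`, far above 1.

```python
import numpy as np

prbs = np.array([0.047480609696438775, 0.40224879672133207,
                 0.23127020434512163, 0.31900038923710755])
prbs = prbs / np.sum(prbs)
bigint = 1e8

Mprbs = (prbs * bigint).astype(int)          # same truncation as in int_update
deficit = int(bigint) - int(np.sum(Mprbs))   # scaled positions covered by no category
print(Mprbs, deficit)                        # deficit is 3 for these probabilities

# A rotated position inside the gap: no cumulative sum exceeds it,
# so argmax over an all-False array falls back to index 0
Mp = float(np.sum(Mprbs))
print(np.argmax(np.cumsum(Mprbs) > Mp))      # 0
```

Only `deficit` of the `bigint` possible positions are affected, so a ten-step demo is unlikely to hit one. One possible remedy, not used in the notebook, is to add the leftover `deficit` to one of the scaled masses so that they sum exactly to `bigint`.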
In [4]:
```python
# demo
n_iters=10
M=1e8                      # integer scale, passed to int_update as bigint
xi=np.pi/16
x=np.zeros(n_iters+1,dtype=int)
u=np.zeros(n_iters+1)
u[0]=np.random.rand()
u[0]=0.10688645379435302 # Trevor's example
u[0]=0.555                 # value used in the run below

# forward map; the summary lines print x 1-based, the loop prints the 0-based index
print('Initial (u,x)=('+str(u[0])+','+str(x[0]+1)+')')
for it in range(n_iters):
    print(x[it],u[it])
    tmpx,tmpu=int_update(x[it],u[it],prbs,xi,bigint=M)
    x[it+1]=tmpx
    u[it+1]=tmpu
# end for
print('Final (u,x)=('+str(u[-1])+','+str(x[-1]+1)+')')
```

```
Initial (u,x)=(0.555,1)
0 0.555
1 0.4356027273568679
1 0.9237323001021333
2 0.7163523013341105
3 0.4098749662931436
0 0.1705831434312119
1 0.3902269538212906
1 0.8783565265665559
2 0.6374301142127261
3 0.3526576363325962
Final (u,x)=(0.9681726711422727,4)
```


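The demo above runs only ten steps. A longer run of the floating-point map gives a rough check that the flow visits each category in proportion to $p(x)$: each step rotates $\rho(u,x)$ by $\xi$ on $[0,1)$, and by Weyl's equidistribution theorem the long-run fraction of steps an irrational rotation spends in $[F(x-1),F(x))$ equals that interval's length $p(x)$. The sketch re-implements `update` as a hypothetical `flow_step` (using `np.searchsorted` for $F^{-1}$) so that it runs on its own; `np.pi/16` stored as a double is technically rational, but over a run of this length it behaves like an irrational shift.

```python
import numpy as np

prbs = np.array([0.047480609696438775, 0.40224879672133207,
                 0.23127020434512163, 0.31900038923710755])
prbs = prbs / np.sum(prbs)
K = len(prbs)
cum = np.cumsum(prbs)                       # cum[k] = F(k) for 0-based k
cum_lo = np.concatenate(([0.0], cum[:-1]))  # cum_lo[k] = F(k-1)

def flow_step(x, u, xi):
    """Floating-point step mirroring update(); x is 0-based."""
    v = (cum_lo[x] + u * prbs[x] + xi) % 1.0    # (rho(u,x) + xi) mod 1
    # smallest k with F(k) > v, clamped in case rounding leaves cum[-1] just below 1
    xp = min(int(np.searchsorted(cum, v, side='right')), K - 1)
    return xp, (v - cum_lo[xp]) / prbs[xp]

n_steps = 20_000
x, u = 0, 0.555
counts = np.zeros(K, dtype=int)
for _ in range(n_steps):
    x, u = flow_step(x, u, np.pi / 16)
    counts[x] += 1

freqs = counts / n_steps
print("empirical: ", np.round(freqs, 4))
print("target:    ", np.round(prbs, 4))
print("max |diff|:", np.max(np.abs(freqs - prbs)))
```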
In [5]:
```python
# backward map
tu=u[-1]
tx=x[-1]
print('Final (u,x)=('+str(tu)+','+str(tx+1)+')')
for it in range(n_iters):
    print(tx,tu)
    tx,tu=int_update(tx,tu,prbs,xi,bigint=M,direction='bwd')
# end for
print('Initial (u,x)=('+str(tu)+','+str(tx+1)+')')
```

```
Final (u,x)=(0.9681726711422727,4)
3 0.9681726711422727
3 0.3526575736367461
2 0.6374299412548612
1 0.8783563774051377
1 0.39022675493939957
0 0.1705810373078689
3 0.4098745901180431
2 0.7163516959815834
1 0.9237319023383513
1 0.4356022798726132
Initial (u,x)=(0.554995935181948,1)
```

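The backward run returns $x=1$ exactly but $u=0.554995935181948$ instead of $0.555$, an error of about $4\times10^{-6}$. In exact arithmetic the map with $-\xi$ would undo the forward map, so the gap comes from the truncations in `int_update` (of $u$, of the probabilities and of $u\,p(x)$) that the backward steps do not reverse. The sketch measures the round-trip error for a few values of `bigint`; it re-implements the notebook's `int_update` with `bigint` as the only scale so that it runs on its own, and `roundtrip` is a hypothetical helper.

```python
import numpy as np

prbs = np.array([0.047480609696438775, 0.40224879672133207,
                 0.23127020434512163, 0.31900038923710755])
prbs = prbs / np.sum(prbs)

def cdf(x, prbs): return np.sum(prbs[:(x + 1)])
def quantile(u, prbs): return np.argmax(np.cumsum(prbs) > u)

def int_update(x, u, prbs, xi, bigint, direction='fwd'):
    """The notebook's integer update, with bigint used throughout as the scale."""
    if direction == 'bwd':
        xi = -xi
    Mu = int(u * bigint)                     # truncate u to the integer grid
    Mprbs = (prbs * bigint).astype(int)      # truncate the probabilities
    Mp = int(Mu * Mprbs[x] / bigint)         # scaled u p(x), truncated
    if x > 0:
        Mp += np.sum(Mprbs[:x])              # + scaled F(x-1)
    Mp = (Mp + int(bigint * xi)) % bigint    # rotate by the scaled xi
    xp = quantile(Mp, Mprbs)
    # same operation order as int_update in the notebook
    return xp, bigint * (Mp - cdf(xp - 1, Mprbs)) / Mprbs[xp] / bigint

def roundtrip(x0, u0, bigint, n_iters=10, xi=np.pi / 16):
    """Run n_iters forward steps, then n_iters backward; return final x and |u - u0|."""
    x, u = x0, u0
    for _ in range(n_iters):
        x, u = int_update(x, u, prbs, xi, bigint)
    for _ in range(n_iters):
        x, u = int_update(x, u, prbs, xi, bigint, direction='bwd')
    return x, abs(u - u0)

for bigint in [1e4, 1e6, 1e8]:
    x_back, err = roundtrip(0, 0.555, bigint)
    print(f"bigint={bigint:.0e}: x recovered={x_back == 0}, |u error|={err:.2e}")
```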