In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sys, time
sys.path.insert(1, '../discrete_mixflows/')
from discrete_mixflows import *
from gibbs import *
from concrete import *

# Global matplotlib defaults for every figure in this notebook.
plt.rcParams.update({
    'figure.max_open_warning': 0,  # 0 disables the "too many open figures" warning
    'figure.figsize': (15, 7.5),   # default figure size in inches
    'font.size': 24,               # large default font for readable plots
})

In [2]:
########################
########################
# target specification #
########################
########################
np.random.seed(2023)  # fixed seed so the random target table (and everything downstream) is reproducible
K1=4  # number of states of the first variable
K2=5  # number of states of the second variable
prbs=np.random.rand(K1,K2)
prbs=prbs/np.sum(prbs)  # normalize into a joint pmf table of shape (K1, K2)
def lp(x,axis=None):
    """
    Log joint, or unnormalized log full conditionals, of the bivariate target pmf `prbs`
    evaluated at a batch of d states.

    Inputs:
        x    : (2,d) integer array with state values; x[0,:] indexes [0,K1), x[1,:] indexes [0,K2)
        axis : int or None, full conditional to calculate; returns joint if None

    Outputs:
        axis is None : (d,) array with log prbs[x[0,i], x[1,i]]; squeezed, so 0-d when d == 1
        axis == 0    : (d,K1) array; row i is log prbs[:, x[1,i]]
        axis == 1    : (d,K2) array; row i is log prbs[x[0,i], :]
        The conditional rows are slices of the log joint and are NOT normalized.

    Raises:
        ValueError : if axis is not None, 0 or 1
    """
    # Index the (K1,K2) log table directly rather than materializing d stacked copies
    # of it: identical values, but O(d*K) work/memory instead of O(d*K1*K2).
    lprb=np.log(prbs)
    if axis is None: return np.squeeze(lprb[x[0,:],x[1,:]])  # for each i, lp(x[:,i])
    if axis==0: return lprb.T[x[1,:],:]  # vary the first variable, second fixed at x[1,i]
    if axis==1: return lprb[x[0,:],:]    # vary the second variable, first fixed at x[0,i]
    raise ValueError("Axis out of bounds - there aren't that many variables")

In [4]:
lprbs = np.log(prbs)  # elementwise log of the (K1, K2) target pmf table

In [41]:
def to_2d(x,K2):
    """
    Unflatten integer codes into (first, second) index pairs, row-major.

    Each code c in [0,K1*K2) maps to the pair (c // K2, c % K2),
    i.e. a tuple in [0,K1) x [0,K2).

    Input:
        x  : (d,) array, flattened codes
        K2 : int, number of values of the second coordinate
    Output:
        x_ : (2,d) array, unflattened pairs (row 0 = first coordinate, row 1 = second)
    """
    first, second = divmod(x, K2)  # same as (x // K2, x % K2)
    return np.vstack((first, second))

def to_1d(x,K2):
    """
    Flatten (first, second) index pairs into integer codes, row-major.

    The pair (i, j) in [0,K1) x [0,K2) maps to the code i*K2 + j in [0,K1*K2);
    this undoes to_2d for in-range pairs.

    Input:
        x  : (2,d) array, unflattened pairs (row 0 = first coordinate, row 1 = second)
        K2 : int, number of values of the second coordinate
    Output:
        x_ : (d,) array, flattened codes
    """
    first, second = x[0,:], x[1,:]
    return first*K2 + second

In [42]:
# Round-trip check: to_2d unflattens every code in [0,K1*K2) into a (2,d) array of
# index pairs, and to_1d flattens them back, so the composition must be the identity.
init=np.arange(0,K1*K2)
print('Initial: '+str(init))
twodim=to_2d(init,K2)
print('Unflattened: '+str(twodim))
onedim=to_1d(twodim,K2)
print('Final: '+str(onedim))
assert twodim.shape==(2,K1*K2)
assert ((twodim[0,:]>=0)&(twodim[0,:]<K1)).all() and ((twodim[1,:]>=0)&(twodim[1,:]<K2)).all()
assert np.array_equal(onedim,init), "to_1d(to_2d(x)) should recover x"

Initial: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
Flattened: [[0 0 0 0 0 1 1 1 1 1 2 2 2 2 2 3 3 3 3 3]
 [0 1 2 3 4 0 1 2 3 4 0 1 2 3 4 0 1 2 3 4]]
Final: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
