In [158]:
suppressPackageStartupMessages(library(extraDistr))
suppressPackageStartupMessages(library(distr))
set.seed(2024)

In [159]:
weight = 1.0
# .GlobalEnv$weight = 1.0
observations = NULL

In [160]:
## Utilities to make the distr library a bit nicer to use
p <- function(distribution, realization) {
  d(distribution)(realization) # return the PMF or density 
}

Bern = function(probability_to_get_one) {
  DiscreteDistribution(supp = 0:1, prob = c(1-probability_to_get_one, probability_to_get_one))
}

## Key functions called by simPPLe programs

# Use simulate(distribution) for unobserved random variables
simulate <- function(distribution) {
  r(distribution)(1) # sample once from the given distribution
}

observe = function(realization, distribution) {
  # `<<-` lets us modify variables that live in the global scope from inside a function
  weight <<- weight * p(distribution, realization) 
  observations <<- c(observations, realization)
}

In [161]:
# Example from lecture
my_ppl_1 = function(n_obs) {
  x = simulate(DiscreteDistribution(supp = 0:2))
  n_outcomes = length(0:2)
  p_x = p(DiscreteDistribution(supp = 0:2), x)
  prob_heads = x/K
  probs = rep(0, n_obs-1)
  for (i in (1:(n_obs-1))) {
    y = simulate(Bern(1 - prob_heads))
    observe(y, Bern(1 - prob_heads)) 
    if (y == 1) {
      probs[i] = 1 - prob_heads
    } else {
      probs[i] = prob_heads
    }
  }
  # return the path and associated info
  return(list(obs = c(x, observations), prob = c(p_x, probs), max_leaf = (n_outcomes-2)*(2^(n_obs-1))+2))
}


# Example from Exercise 01: bigger bag of more biased coins
my_ppl_2 = function(n_obs) {
  K = 9
  rho = seq(1,K+1,1) / sum(seq(1,K+1,1))
  dist_x = DiscreteDistribution(supp = 0:K, prob=rho)
  x = simulate(dist_x)
  n_outcomes = length(0:K)
  p_x = p(dist_x, x)
  prob_heads = rho[x+1]
  probs = rep(0, n_obs-1)
  for (i in (1:(n_obs-1))) {
    y = simulate(Bern(1 - prob_heads))
    observe(y, Bern(1 - prob_heads)) 
    if (y == 1) {
      probs[i] = 1 - prob_heads
    } else {
      probs[i] = prob_heads
    }
  }
  # return the path and associated info
  return(list(obs = c(x, observations), prob = c(p_x, probs), max_leaf = (n_outcomes-2)*(2^(n_obs-1))+2))
  
}


In [162]:
decision_tree = function(ppl_fn, n_level) {
    generated_paths = list() # hashmap for path: key=str(path), value=path
    n_leaf = 0
    complete = FALSE
    mermaid_code = cat("graph TD", "\n", "N[S]", "\n")
    visited = NULL # keeping track of the nodes visited so that avoid duplicate path-so-far
    while (!complete) {
        weight <<- 1
        observations <<- NULL
        path_obj = ppl_fn(n_level)
        path = path_obj$obs
        probs = path_obj$prob
        max_leaf = path_obj$max_leaf
        
        path_str = paste(path, collapse = "")  # Convert path to a string
        if (!path_str %in% names(generated_paths)) {
            generated_paths[[path_str]] = path  # store path
            n_leaf = n_leaf + 1

            # for each path, parse along the entries to update mermaid string
            # format: parent -->|prob| curr[X=0](or curr[Yi=1])
            parent_node = "N"
            visited = c(visited, parent_node)
            for (i in (1:nchar(path_str))) {
                curr_node = paste0("N", substring(path_str, 1, i))
                if (!curr_node %in% visited) {
                    mermaid_code = cat(mermaid_code, paste0(parent_node, "-->|", round(probs[i],2), "| "), curr_node, "[", substring(path_str,i,i), "]", "\n", sep="")
                } 
                parent_node = curr_node
                visited = c(visited, parent_node)
            }
        }
        # Stop if total number of leaves so far = theorectical num of leaves
        if (n_leaf == max_leaf) {
            complete = TRUE
        }
    }
    return (mermaid_code)
}

In [163]:
tree1 = decision_tree(my_ppl_1, 3)

graph TD 
 N[S] 
N-->|0.33| N0[0]
N0-->|1| N01[1]
N01-->|1| N011[1]
N-->|0.33| N2[2]
N2-->|0.78| N21[1]
N21-->|0.22| N210[0]
N-->|0.33| N1[1]
N1-->|0.11| N10[0]
N10-->|0.11| N100[0]
N1-->|0.89| N11[1]
N11-->|0.89| N111[1]
N21-->|0.78| N211[1]
N10-->|0.89| N101[1]


In [164]:
tree2 = decision_tree(my_ppl_2, 3)

graph TD 
 N[S] 
N-->|0.11| N5[5]
N5-->|0.89| N51[1]
N51-->|0.11| N510[0]
N51-->|0.89| N511[1]
N-->|0.15| N7[7]
N7-->|0.85| N71[1]
N71-->|0.85| N711[1]
N-->|0.09| N4[4]
N4-->|0.91| N41[1]
N41-->|0.91| N411[1]
N-->|0.16| N8[8]
N8-->|0.84| N81[1]
N81-->|0.84| N811[1]
N-->|0.18| N9[9]
N9-->|0.82| N91[1]
N91-->|0.82| N911[1]
N-->|0.07| N3[3]
N3-->|0.93| N31[1]
N31-->|0.93| N311[1]
N81-->|0.16| N810[0]
N-->|0.13| N6[6]
N6-->|0.13| N60[0]
N60-->|0.87| N601[1]
N41-->|0.09| N410[0]
N6-->|0.87| N61[1]
N61-->|0.87| N611[1]
N7-->|0.15| N70[0]
N70-->|0.85| N701[1]
N-->|0.05| N2[2]
N2-->|0.95| N21[1]
N21-->|0.95| N211[1]
N61-->|0.13| N610[0]
N-->|0.02| N0[0]
N0-->|0.98| N01[1]
N01-->|0.98| N011[1]
N91-->|0.18| N910[0]
N71-->|0.15| N710[0]
N9-->|0.18| N90[0]
N90-->|0.82| N901[1]
N31-->|0.07| N310[0]
N5-->|0.11| N50[0]
N50-->|0.89| N501[1]
N8-->|0.16| N80[0]
N80-->|0.16| N800[0]
N70-->|0.15| N700[0]
N3-->|0.07| N30[0]
N30-->|0.93| N301[1]
N-->|0.04| N1[1]
N1-->|0.96| N11[1]
N11-->|0.96| N111[1]
N80--