# Simulated data

This notebook runs bartz on simulated data. It is meant to be run on a GPU. You can try it out on Colab with this [link](https://colab.research.google.com/github/bartz-org/bartz/blob/main/docs/examples/basic_simdata.ipynb).
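
Because the notebook expects a GPU, it can help to check which backend JAX picked up before running anything heavy. This is a minimal sketch using only `jax.default_backend()` and `jax.devices()`; the printed device list depends on your machine or Colab runtime.

```python
import jax

# default_backend() returns the platform JAX will run on, e.g. "gpu" or "cpu"
backend = jax.default_backend()
print(f"JAX backend: {backend}")
print(f"devices: {jax.devices()}")

if backend != "gpu":
    # On Colab, switch the runtime type to a GPU before continuing
    print(f"No GPU detected; JAX will use: {backend}")
```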

The next cell installs bartz:

In [1]:
%pip install git+https://github.com/bartz-org/bartz@main

Collecting git+https://github.com/bartz-org/bartz@main
  Cloning https://github.com/bartz-org/bartz (to revision main) to /tmp/pip-req-build-5srkidke
  Running command git clone --filter=blob:none --quiet https://github.com/bartz-org/bartz /tmp/pip-req-build-5srkidke
  Resolved https://github.com/bartz-org/bartz to commit c6874eaea903ace75960c9f85e1f18404d82b425
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


The next cell defines the configuration parameters for this notebook:

In [8]:
n_train = 100_000  # number of training points
p = 1000           # number of predictors/features
sigma = 0.1        # error standard deviation
n_test = 1000      # number of test points
n_tree = 10_000    # number of trees used by bartz

The next cell generates simulated data from a linear + quadratic test model that comes packaged with bartz:

In [2]:
from jax import random

from bartz.testing import gen_data

# list of independent seeds for random sampling
keys = list(random.split(random.key(2024_04_16_18_53), 2))

# simulate data with bartz's built-in testing data generating process
data = gen_data(
    keys.pop(),
    n=n_train + n_test,
    p=p,
    q=2,  # number of interactions: each predictor interacts with q other predictors in the quadratic term
    sigma2_eps=sigma**2,  # error variance
    sigma2_lin=0.5,  # linear term variance
    sigma2_quad=0.5,  # quadratic term variance
    k=1,  # number of outcomes
    lam=1.0,  # correlation between outcomes, unused in this case
)

# split data into training and test sets
train, test = data.split(n_train)
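
To build intuition for what a "linear + quadratic" model with `q` interactions per predictor looks like, here is a minimal JAX sketch of one such data generating process. It is an illustration under assumptions, not the code of `bartz.testing.gen_data`: the standard normal predictors, the cyclic choice of the `q` interaction partners, the coefficient scaling, the `(p, n)` layout of `x`, and the helper name `linear_quadratic_data` are all hypothetical choices made for this sketch.

```python
from jax import numpy as jnp, random


def linear_quadratic_data(key, n, p, q, sigma2_eps, sigma2_lin, sigma2_quad):
    """Hypothetical linear + quadratic data generating process (assumes q < p).

    Illustration only: this is NOT the implementation of bartz.testing.gen_data.
    Returns predictors x with shape (p, n) and outcomes y with shape (n,).
    """
    k_x, k_lin, k_quad, k_eps = random.split(key, 4)

    # standard normal predictors, one row per feature
    x = random.normal(k_x, (p, n))

    # linear term: random coefficients scaled so its variance is about sigma2_lin
    beta = random.normal(k_lin, (p,)) * jnp.sqrt(sigma2_lin / p)
    lin = beta @ x

    # quadratic term: predictor i interacts with predictors i+1, ..., i+q (mod p)
    partners = (jnp.arange(p)[:, None] + jnp.arange(1, q + 1)[None, :]) % p  # (p, q)
    coef = random.normal(k_quad, (p, q)) * jnp.sqrt(sigma2_quad / (p * q))
    # x[partners] has shape (p, q, n); sum over i, m of coef[i, m] * x[i] * x[partners[i, m]]
    quad = jnp.einsum('im,in,imn->n', coef, x, x[partners])

    # independent Gaussian noise with variance sigma2_eps
    eps = random.normal(k_eps, (n,)) * jnp.sqrt(sigma2_eps)
    return x, lin + quad + eps


x, y = linear_quadratic_data(random.key(0), n=1000, p=10, q=2,
                             sigma2_eps=0.1**2, sigma2_lin=0.5, sigma2_quad=0.5)
print(x.shape, y.shape)  # (10, 1000) (1000,)
```

Scaling the coefficients by the number of terms keeps the variance of each component close to the requested value regardless of `p` and `q`, which is one simple way to make `sigma2_lin` and `sigma2_quad` interpretable.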

The next cell runs bartz:

In [3]:
from time import perf_counter

from bartz.BART import gbart

# clock bartz
start = perf_counter()
bart = gbart(train.x, train.y.squeeze(0), ntree=n_tree, printevery=10, seed=keys.pop())
end = perf_counter()

W0202 12:42:52.456641    1279 hlo_rematerialization.cc:3204] Can't reduce memory use below 11.42GiB (12266248029 bytes) by rematerialization; only reduced to 12.37GiB (13285972825 bytes), down from 12.38GiB (13287932865 bytes) originally


..........
Iteration 10/1100, grow prob: 54%, move acc: 37%, fill: 6% (burnin)
..........
Iteration 20/1100, grow prob: 53%, move acc: 36%, fill: 6% (burnin)
..........
Iteration 30/1100, grow prob: 54%, move acc: 34%, fill: 6% (burnin)
..........
Iteration 40/1100, grow prob: 54%, move acc: 34%, fill: 6% (burnin)
..........
Iteration 50/1100, grow prob: 54%, move acc: 34%, fill: 6% (burnin)
..........
Iteration 60/1100, grow prob: 54%, move acc: 33%, fill: 6% (burnin)
..........
Iteration 70/1100, grow prob: 54%, move acc: 33%, fill: 6% (burnin)
..........
Iteration 80/1100, grow prob: 54%, move acc: 33%, fill: 6% (burnin)
..........
Iteration 90/1100, grow prob: 54%, move acc: 32%, fill: 6% (burnin)
..........
Iteration 100/1100, grow prob: 54%, move acc: 33%, fill: 6% (burnin)
..........
Iteration 110/1100, grow prob: 53%, move acc: 32%, fill: 6%
..........
Iteration 120/1100, grow prob: 53%, move acc: 32%, fill: 6%
..........
Iteration 130/1100, grow prob: 53%, move acc: 32%, fill:

Interpretation of the printout:
* grow prob = the fraction of trees that the sampler tried to grow (add leaves to) rather than prune (remove leaves from)
* move acc = in the last iteration, the fraction of proposed tree changes that were accepted
* fill = how much of the fixed tree space is in use; if it is more than ~50%, the trees are too deep and you should increase the number of trees

The fractions refer to the state of the trees after the last iteration; they are not averaged over multiple iterations.

A low acceptance rate means that the trees are changing very slowly.
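
Since the printout is plain text, it can also be checked programmatically, for example to flag when `fill` goes above the ~50% threshold. The helpers `parse_progress` and `trees_too_deep` below are hypothetical (not part of bartz), and the regular expression is written against the exact line format shown in this notebook's output, which may differ in other bartz versions.

```python
import re

# Matches lines like "Iteration 10/1100, grow prob: 54%, move acc: 37%, fill: 6% (burnin)"
PROGRESS_RE = re.compile(
    r"Iteration (\d+)/(\d+), grow prob: (\d+)%, move acc: (\d+)%, fill: (\d+)%( \(burnin\))?"
)


def parse_progress(line):
    """Parse one progress line into a dict; return None if the line doesn't match."""
    match = PROGRESS_RE.search(line)
    if match is None:
        return None
    iteration, total, grow, acc, fill, burnin = match.groups()
    return {
        "iteration": int(iteration),
        "total": int(total),
        "grow_prob": int(grow) / 100,
        "move_acc": int(acc) / 100,
        "fill": int(fill) / 100,
        "burnin": burnin is not None,
    }


def trees_too_deep(stats, threshold=0.5):
    """Rule of thumb from the text: fill above ~50% suggests using more trees."""
    return stats["fill"] > threshold


log = """\
..........
Iteration 10/1100, grow prob: 54%, move acc: 37%, fill: 6% (burnin)
Iteration 110/1100, grow prob: 53%, move acc: 32%, fill: 6%
Iteration 130/1100, grow prob: 53%, move acc: 32%, fill:"""

for line in log.splitlines():
    stats = parse_progress(line)
    if stats is None:
        continue  # progress dots, or a truncated line like the last one
    print(stats["iteration"], stats["move_acc"], trees_too_deep(stats))
```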

The next cell computes the predictions.

In [7]:
from jax import numpy as jnp

# compute predictions
yhat_test = bart.predict(test.x)  # posterior samples, n_samples x n_test
yhat_test_mean = jnp.mean(yhat_test, axis=0)  # posterior mean point-by-point
yhat_test_var = jnp.var(yhat_test, axis=0)  # posterior variance point-by-point

# RMSE
rmse = jnp.sqrt(jnp.mean(jnp.square(yhat_test_mean - test.y)))
expected_error_variance = jnp.mean(jnp.square(bart.sigma))
expected_rmse = jnp.sqrt(jnp.mean(yhat_test_var + expected_error_variance))
avg_sigma = jnp.sqrt(expected_error_variance)

print(f'total sdev: {jnp.std(train.y):#.2g}')
print(f'error sdev: {sigma:#.2g}')
print(f'RMSE: {rmse:#.2g}')
print(f'expected RMSE: {expected_rmse:#.2g}')
print(f'model error sdev: {avg_sigma:#.2g}')
print(f'time: {(end - start) / 60:#.2g} min')

total sdev: 1.0
error sdev: 0.10
RMSE: 0.44
expected RMSE: 0.51
model error sdev: 0.42
time: 8.2 min
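
The statistics printed above only need an array of posterior draws, so the computation can be checked on fake draws without running bartz. In this minimal sketch the `summarize` helper and the fake posterior are hypothetical, and the shapes are assumed to match the comments in the cell above (`n_samples x n_test` for the predictions). Because the "true" function is itself drawn from the fake posterior, the posterior is calibrated, and the RMSE and the expected RMSE should come out close to each other.

```python
from jax import numpy as jnp, random


def summarize(yhat_samples, sigma_samples, y_test):
    """Same statistics as the cell above.

    yhat_samples: (n_samples, n_test) posterior draws of f at the test points
    sigma_samples: (n_samples,) posterior draws of the error sdev
    y_test: (n_test,) observed test outcomes
    """
    f_mean = jnp.mean(yhat_samples, axis=0)  # posterior mean point-by-point
    f_var = jnp.var(yhat_samples, axis=0)  # posterior variance point-by-point
    rmse = jnp.sqrt(jnp.mean(jnp.square(f_mean - y_test)))
    error_var = jnp.mean(jnp.square(sigma_samples))  # posterior mean of sigma^2
    expected_rmse = jnp.sqrt(jnp.mean(f_var + error_var))
    return float(rmse), float(expected_rmse), float(jnp.sqrt(error_var))


# Fake posterior (hypothetical data): draws of f spread by post_sd around post_mean
k_m, k_draw, k_f, k_y = random.split(random.key(0), 4)
n_samples, n_test, sigma, post_sd = 500, 2000, 0.1, 0.05
post_mean = random.normal(k_m, (n_test,))
yhat_samples = post_mean + post_sd * random.normal(k_draw, (n_samples, n_test))
# the "true" f is itself a draw from this posterior, so the posterior is calibrated
f_true = post_mean + post_sd * random.normal(k_f, (n_test,))
y_test = f_true + sigma * random.normal(k_y, (n_test,))
sigma_samples = jnp.full((n_samples,), sigma)

rmse, expected_rmse, avg_sigma = summarize(yhat_samples, sigma_samples, y_test)
print(f'RMSE: {rmse:#.2g}, expected RMSE: {expected_rmse:#.2g}, model error sdev: {avg_sigma:#.2g}')
```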


In expectation, the RMSE cannot be lower than the error standard deviation used to generate the data (`sigma` = 0.10 here).
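
A short derivation of this bound, and of the `expected_rmse` formula used in the cell above, assuming each test outcome follows $y_i = f(x_i) + \varepsilon_i$ with independent noise $\varepsilon_i$ of mean 0 and variance $\sigma^2$. For any prediction $\hat y_i$ that does not depend on $\varepsilon_i$:

$$
\mathbb{E}_\varepsilon\big[(y_i - \hat y_i)^2\big] = \big(f(x_i) - \hat y_i\big)^2 + \sigma^2 \ \ge\ \sigma^2 ,
$$

so the mean squared error over the test set is at least $\sigma^2$ in expectation. If instead $f$ and $\sigma$ are treated as random under the posterior and $\hat y_i = \bar f_i$ is the posterior mean, the same decomposition gives

$$
\mathbb{E}_{\text{post}}\big[(y_i - \bar f_i)^2\big] = \operatorname{Var}_{\text{post}}\big[f(x_i)\big] + \mathbb{E}_{\text{post}}\big[\sigma^2\big],
\qquad
\text{expected RMSE} = \sqrt{\frac{1}{n_\text{test}} \sum_{i=1}^{n_\text{test}} \Big(\operatorname{Var}_{\text{post}}\big[f(x_i)\big] + \mathbb{E}_{\text{post}}\big[\sigma^2\big]\Big)} ,
$$

which is what `expected_rmse` computes from `yhat_test_var` and `bart.sigma`.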