In [1]:
import matplotlib as mpl
mpl.rcParams['font.family'] = 'Arial'
mpl.rcParams['text.usetex'] = False
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
import time
from utrees import Baltobot
from treeffuser import Treeffuser
# Generate the data: a two-branch wave observed with Laplace noise.
# NOTE: the order of the rng draws (x, then z, then the noise) is part of the
# recipe -- reordering them changes the dataset produced for a given seed.
seed = 0  # also passed to the models fitted later in the notebook
n = 5000  # number of observations
rng = np.random.default_rng(seed=seed)
x = rng.uniform(0, 2 * np.pi, size=n)  # covariate drawn uniformly from [0, 2*pi)
z = rng.integers(0, 2, size=n)  # branch indicator, 0 or 1 (upper bound is exclusive)
# z == 1 selects the sin(x - pi/2) branch, z == 0 the cos(x) branch; the scale
# of the additive Laplace noise grows linearly with x (x / 30).
y = z * np.sin(x - np.pi / 2) + (1 - z) * np.cos(x) + rng.laplace(scale=x / 30, size=n)

In [2]:
# Fit the models and time them.
# For each model two durations are recorded, both measured from just before the
# model object is constructed:
#   *_train_time -- stops after fit()
#   *_time       -- stops after sample(), i.e. construction + fit + sampling
# time.perf_counter() is used rather than time.time(): it is a monotonic,
# high-resolution clock meant for measuring intervals, whereas the wall clock
# can jump if the system time is adjusted while a model is training.
x_col = x.reshape(-1, 1)  # x as a single-feature (n, 1) matrix for the Baltobot models

# Treeffuser (receives x as the original 1-D array, exactly as before)
start_time = time.perf_counter()
tfer = Treeffuser(sde_initialize_from_data=True, seed=seed)
tfer.fit(x, y)
tf_train_time = time.perf_counter() - start_time
y_tfer = tfer.sample(x, n_samples=1, seed=seed, verbose=True)
tf_time = time.perf_counter() - start_time

# Baltobot with default settings
start_time = time.perf_counter()
tber = Baltobot(random_state=seed)
tber.fit(x_col, y)
tb_train_time = time.perf_counter() - start_time
y_tber = tber.sample(x_col)
tb_time = time.perf_counter() - start_time

# Baltobot with tabpfn=True
start_time = time.perf_counter()
tbtaber = Baltobot(tabpfn=True, random_state=seed)
tbtaber.fit(x_col, y)
tbtab_train_time = time.perf_counter() - start_time
y_tbtaber = tbtaber.sample(x_col)
tbtab_time = time.perf_counter() - start_time

  X = _check_array(X)
  y = _check_array(y)
  X = _check_array(X)
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Figure 1: one panel per model, each overlaying that model's samples on the
# observed data. Saved to 'wave-demo.png' and closed without inline display.
fig, axes = plt.subplots(nrows=3, figsize=(7, 7), sharex=True, dpi=300)
wave_panels = [
    (y_tfer[0, :], "Treeffuser samples"),  # row 0 = the single sample drawn per point
    (y_tber, "Baltobot samples"),
    (y_tbtaber, "BaltoboTabPFN samples"),
]
for panel_ax, (model_samples, sample_label) in zip(axes, wave_panels):
    panel_ax.scatter(x, y, s=1, label="Observed data")
    panel_ax.scatter(x, model_samples, s=1, alpha=0.7, label=sample_label)
    panel_ax.legend()
plt.tight_layout()
plt.savefig('wave-demo.png')
plt.close()

# Timing table: one row per method with Total / Training columns; Sampling is
# derived as the difference between the two.
method_names = ['Treeffuser', 'Baltobot', 'BaltoboTabPFN']
total_time_df = pd.DataFrame({'Total': [tf_time, tb_time, tbtab_time]}, index=method_names)
train_time_df = pd.DataFrame(
    {'Training': [tf_train_time, tb_train_time, tbtab_train_time]}, index=method_names
)
time_df = pd.concat([total_time_df, train_time_df], axis=1)
time_df['Sampling'] = time_df['Total'] - time_df['Training']

# Long format (one row per method/task pair), which is what seaborn's `hue` needs.
time_dff = time_df.stack().reset_index()
time_dff.columns = ['Method', 'Task', 'Time']

# Figure 2: grouped horizontal bars of the timings, saved to 'wave-demo-time.png'.
plt.figure(dpi=200, figsize=(4, 3))
sns.barplot(data=time_dff, y='Method', x='Time', hue='Task')
plt.ylabel('Method')
plt.xlabel('Time (s)')
plt.tight_layout()
plt.savefig('wave-demo-time.png')
plt.close()

# Last expression of the cell: display the long-format timing table.
time_dff

Unnamed: 0,Method,Task,Time
0,Treeffuser,Total,6.401965
1,Treeffuser,Training,1.356733
2,Treeffuser,Sampling,5.045232
3,Baltobot,Total,3.175086
4,Baltobot,Training,2.396088
5,Baltobot,Sampling,0.778998
6,BaltoboTabPFN,Total,12.216128
7,BaltoboTabPFN,Training,2.14845
8,BaltoboTabPFN,Sampling,10.067678


In [4]:
# Evaluate tber.score_samples on a regular 30 x 30 grid over
# [-1, 7] x [-3, 2]. Both figures below are closed immediately, so this cell
# neither displays nor saves anything.
grid_first = np.linspace(-1, 7, 30)
grid_second = np.linspace(-3, 2, 30)
lhs, rhs = np.meshgrid(grid_first, grid_second)
# One row per grid point: column 0 is the first coordinate, column 1 the second.
lhsrhs = np.column_stack([lhs.ravel(), rhs.ravel()])

# Plain scatter of the grid points.
plt.figure()
plt.scatter(lhsrhs[:, 0], lhsrhs[:, 1])
# First argument is passed as an (n_points, 1) matrix, second as a flat vector.
scores = tber.score_samples(lhs.reshape(-1, 1), rhs.reshape(-1))
plt.close()

# Same grid with marker area proportional to exp(score)
# (presumably the scores are log-densities -- confirm against Baltobot's docs).
plt.figure()
plt.scatter(lhsrhs[:, 0], lhsrhs[:, 1], s=100 * np.exp(scores))
plt.close()

In [5]:
# Compare the two Baltobot variants at the fixed input x = 2, over a grid of
# 1000 y values spanning [-5, 5].
Xs = np.full((1000, 1), 2.0)
Ys = np.linspace(-5, 5, 1000)
tb_scores = tber.score_samples(Xs, Ys)
tbtab_scores = tbtaber.score_samples(Xs, Ys)

# exp(score) is plotted as the pdf below; its Riemann sum over the y grid is
# printed for each model as a sanity check on the total mass.
dy = Ys[1] - Ys[0]
tb_pdf = np.exp(tb_scores)
tbtab_pdf = np.exp(tbtab_scores)
print(tb_pdf.sum() * dy, tbtab_pdf.sum() * dy)

# Overlay both curves, zoomed to y in [-2, 2]; the legend sits outside the axes.
plt.figure(figsize=(4, 2), dpi=200)
plt.plot(Ys, tb_pdf, label='Baltobot')
plt.plot(Ys, tbtab_pdf, '--', label='BaltoboTabPFN')
plt.xlim(-2, 2)
plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.05))
plt.xlabel('y')
plt.ylabel('pdf at x=2')
plt.tight_layout()
plt.savefig('wave-pdfat2.png')
plt.close()

0.9820769258371294 0.9859963150095391


In [6]:
# Poisson demo: count-valued targets whose rate is sqrt(x).
# NOTE(review): this cell rebinds `rng`, `tfer` and `tber`. After it runs, the
# models fitted on the wave data are no longer reachable under those names, so
# re-running an earlier cell that uses them would silently pick up these ones.
nP = 500
rng = np.random.default_rng(seed=seed)  # fresh generator, re-seeded with the same seed
XP = rng.uniform(low=0, high=3, size=nP)
YP = rng.poisson(lam=np.sqrt(XP), size=nP)
XP_col = XP.reshape(-1, 1)  # single-feature (nP, 1) matrix for Baltobot

# Treeffuser: fit on the 1-D arrays and draw one sample per input point.
tfer = Treeffuser(sde_initialize_from_data=True, seed=seed)
tfer.fit(XP, YP)
YP_tfer = tfer.sample(XP, n_samples=1, seed=seed, verbose=True)

# Baltobot: fit and sample on the column-shaped inputs.
tber = Baltobot(random_state=seed)
tber.fit(XP_col, YP)
YP_tber = tber.sample(XP_col)

  X = _check_array(X)
  y = _check_array(y)
  X = _check_array(X)


In [7]:
# Three side-by-side scatter plots sharing the y axis: the Poisson data and one
# sample per point from each model. Saved to 'poisson-demo.png' and closed.

# Marker styling shared by every panel.
s = 8
linewidth = 0.3
edgecolor = 'white'
markercolor = 'blue'

# One x/y frame per panel. `assign` on an empty frame sets the columns one at a
# time, in the order given.
dfP = pd.DataFrame().assign(x=XP, y=YP)
dfP_tfer = pd.DataFrame().assign(x=XP, y=YP_tfer.ravel())  # flatten the samples to 1-D
dfP_tber = pd.DataFrame().assign(x=XP, y=YP_tber)

fig, axes = plt.subplots(figsize=(7, 3), ncols=3, sharey=True, dpi=500)
poisson_panels = [
    (dfP, 'Original data'),
    (dfP_tfer, 'Treeffuser'),
    (dfP_tber, 'Baltobot'),
]
for panel_ax, (panel_df, panel_title) in zip(axes, poisson_panels):
    sns.scatterplot(
        data=panel_df,
        x='x',
        y='y',
        s=s,
        edgecolor=edgecolor,
        linewidth=linewidth,
        color=markercolor,
        ax=panel_ax,
    )
    panel_ax.set_title(panel_title)
plt.tight_layout()
plt.savefig('poisson-demo.png')
plt.close()