In [2]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow.keras.backend as K

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from scipy.optimize import minimize, fmin_l_bfgs_b, fmin_ncg
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.initializers import GlorotUniform

from bore.datasets import make_regression_dataset, make_classification_dataset
# from bore.decorators import negate, unbatch, make_value_and_gradient_fn, numpy_o

In [3]:
K.set_floatx("float64")

num_samples = 15
num_features = 1

num_index_points = 512
xmin, xmax = -1.0, 2.0
X_grid = np.linspace(xmin, xmax, num_index_points).reshape(-1, num_features)

noise_variance = 0.2
quantile = 1/3

seed = 42
random_state = np.random.RandomState(seed)

In [4]:
def latent(x):
#     return (6.0*x-2.0)**2 * np.sin(12.0*x - 4.0)
    return np.sin(3.0*x) + x**2 - 0.7*x

In [6]:
load_observations = make_regression_dataset(latent)
X, y = load_observations(num_samples=num_samples,
                         num_features=num_features,
                         noise_variance=noise_variance,
                         x_min=xmin, x_max=xmax,
                         random_state=random_state)

In [7]:
import heapq

In [18]:
y

array([-0.57487186,  1.65858566,  0.27736031,  1.35581601,  0.58882881,
        0.55933779,  1.01513697,  0.14056748, -0.48741114,  0.78465262,
        0.02328776,  0.13607386, -0.1462268 ,  0.75616402,  0.40370507])

In [19]:
lst = list(y)
lst

[-0.5748718631108038,
 1.6585856572472957,
 0.27736030726910965,
 1.3558160137662398,
 0.5888288120705181,
 0.5593377873729888,
 1.0151369716187753,
 0.14056748231143446,
 -0.4874111361861066,
 0.7846526153931701,
 0.023287759286624216,
 0.13607385875184191,
 -0.14622679715264675,
 0.7561640204306614,
 0.4037050656344667]

In [20]:
heapq.heapify(lst)

In [21]:
lst

[-0.5748718631108038,
 -0.4874111361861066,
 -0.14622679715264675,
 0.14056748231143446,
 0.023287759286624216,
 0.13607385875184191,
 0.4037050656344667,
 1.6585856572472957,
 1.3558160137662398,
 0.7846526153931701,
 0.5888288120705181,
 0.27736030726910965,
 0.5593377873729888,
 0.7561640204306614,
 1.0151369716187753]

In [22]:
heapq.heappop(lst)

-0.5748718631108038

In [28]:
13/3

4.333333333333333

In [33]:
def top_k(heap, k):
    
    for i in range(k):
        yield heapq.heappop(heap)

In [34]:
list(top_k(lst, 5))

IndexError: index out of range

In [38]:
lst = []

for (u, v) in zip(X, y):

    heapq.heappush(lst, (v, u))

    print(lst)

[(-0.5748718631108038, array([-0.44543663]))]
[(-0.5748718631108038, array([-0.44543663])), (1.6585856572472957, array([1.90875388]))]
[(-0.5748718631108038, array([-0.44543663])), (1.6585856572472957, array([1.90875388])), (0.27736030726910965, array([1.32539847]))]
[(-0.5748718631108038, array([-0.44543663])), (1.3558160137662398, array([1.81849682])), (0.27736030726910965, array([1.32539847])), (1.6585856572472957, array([1.90875388]))]
[(-0.5748718631108038, array([-0.44543663])), (0.5888288120705181, array([1.68448205])), (0.27736030726910965, array([1.32539847])), (1.6585856572472957, array([1.90875388])), (1.3558160137662398, array([1.81849682]))]
[(-0.5748718631108038, array([-0.44543663])), (0.5888288120705181, array([1.68448205])), (0.27736030726910965, array([1.32539847])), (1.6585856572472957, array([1.90875388])), (1.3558160137662398, array([1.81849682])), (0.5593377873729888, array([0.79369994]))]
[(-0.5748718631108038, array([-0.44543663])), (0.5888288120705181, array([1

In [51]:
b = list(lst)
c, d = list(zip(*[heapq.heappop(b) for i in range(5)]))

In [56]:
y_low = np.hstack(c)
y_low

array([-0.57487186, -0.48741114, -0.1462268 ,  0.02328776,  0.13607386])

In [57]:
X_low = np.vstack(d)
X_low

array([[-0.44543663],
       [-0.41205141],
       [-0.1859529 ],
       [-0.02400901],
       [ 0.16603187]])

In [58]:
b

[(0.14056748231143446, array([-0.73452249])),
 (0.5888288120705181, array([1.68448205])),
 (0.27736030726910965, array([1.32539847])),
 (0.7846526153931701, array([-0.86431813])),
 (0.7561640204306614, array([1.48621253])),
 (0.5593377873729888, array([0.79369994])),
 (0.4037050656344667, array([0.07025998])),
 (1.6585856572472957, array([1.90875388])),
 (1.0151369716187753, array([1.76562271])),
 (1.3558160137662398, array([1.81849682]))]