In [8]:
import numpy as np
import matplotlib.pyplot as plt
import random
import sys

# Put the parent directory on the module search path (at index 1, i.e. right
# after the first entry) so that code located one level above this notebook
# can be imported.
# NOTE(review): presumably this is what makes the `disentanglement` package
# importable from here — confirm against the repository layout.
sys.path.insert(1, '../')

In [9]:
def point_on_triangle(pt1, pt2, pt3):
    """
    Draw one random point inside the triangle spanned by pt1, pt2 and pt3.

    Two draws from the global ``random`` generator are turned into three
    weights that sum to one (the gaps between the two draws once they are
    placed on the unit interval); the result is the matching weighted
    combination of the vertices, which is uniformly distributed over the
    triangle.

    Parameters
    ----------
    pt1, pt2, pt3 : indexable
        Triangle vertices. Only elements ``[0]`` and ``[1]`` are read.

    Returns
    -------
    tuple
        ``(x, y)`` coordinates of the sampled point.
    """
    first = random.random()
    second = random.random()
    gap = abs(first - second)

    # w1 is the distance between the draws, w2 equals the smaller draw and
    # w3 equals one minus the larger draw, hence w1 + w2 + w3 == 1.
    w1 = gap
    w2 = 0.5 * (first + second - gap)
    w3 = 1 - 0.5 * (gap + first + second)

    def blend(axis):
        # Weighted combination of one coordinate of the three vertices.
        return w1 * pt1[axis] + w2 * pt2[axis] + w3 * pt3[axis]

    return blend(0), blend(1)


# --- Synthetic two-class data set: two overlapping triangles ---------------
# Both triangles have their tip at x = 0 and widen towards x = traingle_length.
# At the wide end class 1 spans y in [-overlap, end] and class 2 spans
# y in [-end, overlap], so the classes overlap in a band around y = 0.
overlap_rate = 0.3
end = 1
# NOTE: the spelling of `traingle_length` is kept on purpose, since other
# cells may reference this name.
traingle_length = 10
overlap = end * overlap_rate
n_samples = 5000

# Fix the seed of the global `random` generator (the source of randomness in
# point_on_triangle) so the sampled data set — and everything derived from
# it — is identical on every Restart & Run All.
seed = 0
random.seed(seed)

# Each class receives half of the samples.
samples_per_class = n_samples // 2

class_1 = np.array([
    point_on_triangle(
        (0, (end - overlap) / 2),
        (traingle_length, -overlap),
        (traingle_length, end),
    )
    for _ in range(samples_per_class)
])
class_2 = np.array([
    point_on_triangle(
        (0, (end - overlap) / -2),
        (traingle_length, overlap),
        (traingle_length, -end),
    )
    for _ in range(samples_per_class)
])

# Quick visual check of the two point clouds.
plt.scatter(class_1[:, 0], class_1[:, 1])
plt.scatter(class_2[:, 0], class_2[:, 1])

In [10]:
# Feature matrix: the rows of class_1 followed by the rows of class_2.
x = np.concatenate((class_1, class_2), axis=0)
# Label vector aligned with x: 0.0 for every class_1 row, 1.0 for every
# class_2 row.
y = np.repeat([0.0, 1.0], [class_1.shape[0], class_2.shape[0]])

In [11]:
from disentanglement.models.architectures import get_blobs_dropout_architecture
from disentanglement.models.entropy_models import mutual_information, expected_entropy, \
    train_entropy_model
from disentanglement.models.multi_head_models import uncertainty, \
    train_disentangle_model
from disentanglement.settings import BATCH_SIZE, NUM_SAMPLES, TEST_MODE
from tqdm import tqdm
from sklearn.metrics import accuracy_score


# Training configuration shared by both models.
epochs = 50
n_classes = 2
architecture_func = get_blobs_dropout_architecture

# Train the two models on the same data with the same architecture factory.
# The disentangling model is trained first, then the entropy model.
disentangle_model = train_disentangle_model(
    architecture_func, x, y, n_classes, epochs=epochs
)
entropy_model = train_entropy_model(
    architecture_func, x, y, n_classes, epochs=epochs
)


In [22]:

# Regular evaluation grid on which the uncertainty maps are computed.
res = 0.5  # grid spacing along both axes

# Lower / upper bounds as [x, y] pairs; the upper bounds are exclusive
# because they are used as the `stop` argument of np.arange.
min_x = [-2, -2 + res]
max_x = [15, 2 + res]

x_ticks = np.arange(min_x[0], max_x[0], res)
y_ticks = np.arange(min_x[1], max_x[1], res)
xx, yy = np.meshgrid(x_ticks, y_ticks)

# One (x, y) row per grid node, in the row-major order of the mesh.
domain = np.column_stack((xx.ravel(), yy.ravel()))


In [23]:
# Evaluate both trained models on every node of the grid.
# NOTE(review): judging by the variable names, the three outputs are the
# predictive mean and the aleatoric / epistemic standard deviations —
# confirm against disentangle_model.predict.
pred_mean, pred_ale_std, pred_epi_std = disentangle_model.predict(
    domain,
    batch_size=BATCH_SIZE,
)
entropy_preds = entropy_model.predict_samples(
    domain,
    num_samples=NUM_SAMPLES,
    batch_size=BATCH_SIZE,
)

# Every per-node score is folded back into the 2-D layout of the mesh grid.
grid_shape = xx.shape

# Uncertainty maps derived from the disentangling model's two std outputs.
gauss_logits_aleatoric = uncertainty(pred_ale_std).reshape(grid_shape)
gauss_logits_epistemic = uncertainty(pred_epi_std).reshape(grid_shape)

# Uncertainty maps derived from the entropy model's sampled predictions.
it_aleatoric = expected_entropy(entropy_preds).reshape(grid_shape)
it_epistemic = mutual_information(entropy_preds).reshape(grid_shape)

In [30]:
import matplotlib.pyplot as plt
import matplotlib.pylab as pl
from matplotlib.colors import ListedColormap
# Variant of matplotlib's "binary" colormap in which the last (alpha) column
# of every colour entry is set to 0.7.
cmap = pl.cm.binary
rgba_table = cmap(np.arange(cmap.N))
rgba_table[:, -1] = 0.7
my_cmap = ListedColormap(rgba_table)


def _draw_uncertainty_panel(ax, values, title):
    """Draw one uncertainty map on `ax`.

    Fills the axes with contours of `values` over the (xx, yy) grid, overlays
    the training points coloured by label, hides the axis ticks, freezes the
    axis limits and sets `title`. Returns the contour set from `contourf`.
    """
    contours = ax.contourf(xx, yy, values, antialiased=True)
    ax.scatter(x[:, 0], x[:, 1], c=y, cmap=my_cmap, alpha=0.05)
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])
    ax.autoscale(False)
    ax.set_title(title)
    return contours


# One row, two panels: aleatoric on the left, epistemic on the right.
fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(8, 4), squeeze=False)
ax_ale, ax_epi = axes[0]

# Information-theoretic maps. The Gaussian-logits maps
# (gauss_logits_aleatoric / gauss_logits_epistemic) can be drawn with the
# same helper if that comparison is needed.
cf_ale = _draw_uncertainty_panel(ax_ale, it_aleatoric, "Aleatoric uncertainty")
cf_epi = _draw_uncertainty_panel(ax_epi, it_epistemic, "Epistemic uncertainty")

fig.tight_layout()
plt.savefig('../figures/it_traingles.pdf')
plt.show()