In [None]:
# show plots inline in the notebook
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, Math
import zeus
import seaborn as sns
from sklearn.metrics import adjusted_rand_score
import numpy as np
import zeus
from sklearn.cluster import AgglomerativeClustering
from data_helper import load_umap_newsgroups
from tqdm import tqdm
# Matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*" and
# newer releases no longer accept the old names. Prefer the new name when the
# installed Matplotlib offers it and fall back to the legacy name otherwise.
_darkgrid_style = next(
    (
        style_name
        for style_name in ("seaborn-v0_8-darkgrid", "seaborn-darkgrid")
        if style_name in plt.style.available
    ),
    "seaborn-darkgrid",
)
plt.style.use(_darkgrid_style)

def sigmoid(x):
    """Evaluate the logistic function ``1 / (1 + exp(-x))`` element-wise.

    The value is computed as ``exp(-log(1 + exp(-x)))`` via ``np.logaddexp``
    so that large negative inputs do not overflow in an intermediate
    ``exp(-x)`` (which would emit a RuntimeWarning in the naive formula).

    Args:
        x: Real scalar or NumPy array.

    Returns:
        NumPy scalar or array shaped like ``x`` holding the logistic of
        each element.
    """
    # logaddexp(0, -x) == log(1 + exp(-x)), evaluated without overflow.
    return np.exp(-np.logaddexp(0.0, -x))

# Load the feature matrix and labels through the project helper.
# NOTE(review): presumably `sampling` subsamples the corpus and `n_components`
# is the UMAP embedding dimension -- confirm against data_helper.
_newsgroups_options = {"sampling": 10, "n_components": 10}
X_news, y_news = load_umap_newsgroups(**_newsgroups_options)


In [None]:
# Rand Scores and Distance Thresholds
def get_rand_score(distance_threshold, X, y, eps=1e-3, linkage="ward"):
    """Cluster ``X`` at a distance threshold and score it against ``y``.

    Args:
        distance_threshold: Merge-distance cut-off handed to
            ``AgglomerativeClustering``. May be a Python/NumPy scalar or a
            size-1 array (e.g. a one-dimensional sampler parameter).
        X: Feature matrix to cluster.
        y: Reference labels compared against the predicted clusters.
        eps: Constant added to the adjusted Rand score before returning.
        linkage: Linkage criterion forwarded to ``AgglomerativeClustering``.
            The default ``"ward"`` is scikit-learn's own default, i.e. what
            this function used before the parameter existed.

    Returns:
        ``eps`` plus the adjusted Rand score of the predicted clustering.

    Raises:
        ValueError: If ``distance_threshold`` holds more than one element.
    """
    # Recent scikit-learn releases validate that distance_threshold is a real
    # scalar, so unwrap size-1 arrays instead of passing them through.
    threshold = float(np.asarray(distance_threshold).item())
    model = AgglomerativeClustering(
        n_clusters=None,
        distance_threshold=threshold,
        linkage=linkage,
    )
    predicted = model.fit_predict(X)
    # The metric is symmetric; arguments follow its documented
    # (labels_true, labels_pred) order for readability.
    return eps + adjusted_rand_score(y, predicted)


# Sweep the threshold over [0, 50) in steps of 0.2 and record the
# eps-shifted adjusted Rand score obtained at each value.
distance_thresholds = np.arange(0, 50, 0.2)
scores = list(
    map(
        lambda threshold: get_rand_score(threshold, X_news, y_news),
        distance_thresholds,
    )
)


In [None]:
# Run MCMC to sample from distribution over learned distance thresholds

def get_distance_threshold(param):
    """Map an unconstrained sampler parameter to a positive threshold.

    Args:
        param: Real scalar or NumPy array of sampler parameters.

    Returns:
        ``exp(-param)``, element-wise, with the same shape as ``param``.
    """
    negated = -param
    return np.exp(negated)

def logprior(param):
    """Return the log-prior of ``param``.

    The value is the constant 0 for every input, i.e. a flat (unnormalised)
    prior that does not depend on ``param``.
    """
    flat_log_density = 0
    return flat_log_density

def loglikelihood(param, X, y):
    """Score the clustering induced by ``param`` against the labels ``y``.

    ``param`` is first mapped to a distance threshold, and the value
    returned by ``get_rand_score`` at that threshold is used directly as
    the log-likelihood.

    Args:
        param: Sampler parameter, converted with ``get_distance_threshold``.
        X: Feature matrix to cluster.
        y: Reference labels.

    Returns:
        The result of ``get_rand_score`` for the derived threshold.
    """
    threshold = get_distance_threshold(param)
    score = get_rand_score(threshold, X, y)
    return score

def logposterior(param, X, y):
    """Unnormalised log-posterior: log-prior plus log-likelihood.

    Args:
        param: Sampler parameter.
        X: Feature matrix forwarded to ``loglikelihood``.
        y: Reference labels forwarded to ``loglikelihood``.

    Returns:
        ``logprior(param) + loglikelihood(param, X, y)``.
    """
    prior_term = logprior(param)
    likelihood_term = loglikelihood(param, X, y)
    return prior_term + likelihood_term


# Seed NumPy's global RNG so the walker start positions are identical on
# every Restart & Run All.
# NOTE(review): zeus appears to draw from NumPy's global RNG as well, which
# would make the whole chain reproducible -- confirm for the installed version.
MCMC_SEED = 42
np.random.seed(MCMC_SEED)

ndim = 1        # a single parameter: the (log-scale) distance threshold
nwalkers = 4    # number of ensemble walkers
nsteps = 100    # MCMC iterations per walker

# One random starting point in [0, 1) per walker and dimension.
start = np.random.rand(nwalkers, ndim)

sampler = zeus.EnsembleSampler(nwalkers, ndim, logposterior, args=[X_news, y_news], verbose=False)
sampler.run_mcmc(start, nsteps)

# Drop the first half of every walker's chain and flatten the remaining
# samples across walkers.
chain = sampler.get_chain(flat=True, discard=nsteps // 2)


In [None]:
%matplotlib inline


# Two side-by-side panels: sampled thresholds (left) and the score sweep (right).
fig, (ax_samples, ax_scores) = plt.subplots(1, 2, figsize=(25, 10))

# get_rand_score runs AgglomerativeClustering with scikit-learn's default
# linkage, which is Ward, so the axis labels name that method.
threshold_label = "Distance Threshold Used By Ward-Linkage Agglomerative Clustering"

# `chain` holds one row per sample. Flatten the converted thresholds into a
# single 1-D array: a list of length-1 arrays would be interpreted by
# Matplotlib as many separate one-sample datasets instead of one distribution.
sampled_thresholds = np.ravel(get_distance_threshold(chain))

ax_samples.hist(sampled_thresholds, density=True)
ax_samples.set_title("Samples from {}_n Fit Over 20 Newsgroups Dataset".format(chr(956)), fontsize=20)
ax_samples.set_xlabel(threshold_label, fontsize=20)
ax_samples.set_ylabel("Normalized Proportion of Samples From {}_n".format(chr(956)), fontsize=20)

ax_scores.plot(distance_thresholds, scores)
ax_scores.set_title("Distance Threshold and Adjusted Rand Score Over 20 Newsgroups Dataset", fontsize=20)
ax_scores.set_xlabel(threshold_label, fontsize=20)
ax_scores.set_ylabel("Adjusted Rand Score Over 20 Newsgroups Dataset", fontsize=20)

plt.show()