In [None]:
import numpy as np
import holoviews as hv
from sklearn.mixture import BayesianGaussianMixture
import pandas as pd
import logging
from scipy.spatial.distance import cdist
from scipy.optimize import linear_sum_assignment
from dpm.dpgmm import WeightedDPGMM
# Load the Bokeh plotting backend so the HoloViews figures below render in the notebook.
hv.extension("bokeh")

## Compare runtime
Generate synthetic 2-D data: labelled Gaussian blobs around randomly placed cluster centres.

In [None]:
def make_data(num_clusters=10, N=100000, noise_scale=.08, extent=10, random_state=None):
    """Generate a labelled 2-D Gaussian-blob dataset.

    Cluster centres are drawn uniformly from the square [-extent, extent)^2. Every
    point is assigned to a uniformly random cluster and displaced from its centre by
    Gaussian noise with per-coordinate standard deviation ``noise_scale``.

    Parameters
    ----------
    num_clusters : int
        Number of cluster centres.
    N : int
        Number of points to generate.
    noise_scale : float
        Standard deviation of the noise around each centre.
    extent : float
        Half-width of the square the centres are drawn from.
    random_state : None, int or np.random.RandomState
        ``None`` uses NumPy's global RNG (so ``np.random.seed`` still controls it);
        an int seeds a private ``RandomState``; a ``RandomState`` is used as-is.

    Returns
    -------
    pd.DataFrame
        Columns ``x`` and ``y`` (coordinates) and ``true_cluster`` (cluster index as
        a string, so plotting libraries treat it as categorical).
    """
    if random_state is None:
        rng = np.random
    elif isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    # Same draw order as before (centres, labels, noise) so seeded runs line up.
    centres = 2 * extent * rng.rand(num_clusters, 2) - extent
    y = rng.randint(num_clusters, size=N)
    noise = noise_scale * rng.randn(N, 2)

    # Pick each point's centre by label. The previous one-hot broadcast summed the
    # (N, 1, 2) noise over all clusters, inflating its std to num_clusters * noise_scale.
    x = centres[y] + noise

    x_df = pd.DataFrame(x, columns=["x", "y"])
    x_df["true_cluster"] = y.astype(str)
    return x_df

Set up the timing helpers: each fits one model and reports the mean wall-clock time per fit.

In [None]:
from tqdm import notebook as tqdm
from timeit import timeit

# Fitting settings shared by both models: tol -> `tol`, num_iters -> `max_iter`,
# cov_type -> `covariance_type` (see run_model / run_model_unweighted).
tol, num_iters, cov_type = 1e-6, 1000, "diag"

def run_model(x, w, seed, n_components=20):
    """Fit a WeightedDPGMM to ``x`` with sample weights ``w`` and return its labels.

    Parameters
    ----------
    x : samples passed to ``fit_predict``.
    w : forwarded as ``sample_weight`` (NOTE(review): None presumably means
        unweighted -- depends on WeightedDPGMM, confirm).
    seed : ``random_state`` for the model (None -> non-deterministic).
    n_components : number of mixture components given to the model (default 20).

    Returns
    -------
    The result of ``fit_predict`` (previously computed and discarded).

    Uses the module-level ``num_iters``, ``tol`` and ``cov_type`` settings.
    """
    model = WeightedDPGMM(
        n_components=n_components,
        verbose=0,
        max_iter=num_iters,
        tol=tol,
        covariance_type=cov_type,
        random_state=seed,
    )
    return model.fit_predict(x, sample_weight=w)
    
def run_model_unweighted(x, seed, n_components=20):
    """Fit sklearn's BayesianGaussianMixture to ``x`` and return its labels.

    Parameters
    ----------
    x : samples passed to ``fit_predict``.
    seed : ``random_state`` for the model (None -> non-deterministic).
    n_components : upper bound on the number of mixture components (default 20).

    Returns
    -------
    The cluster labels from ``fit_predict`` (previously computed and discarded).

    Uses the module-level ``num_iters``, ``tol`` and ``cov_type`` settings so the
    configuration matches ``run_model``.
    """
    model = BayesianGaussianMixture(
        n_components=n_components,
        verbose=0,
        max_iter=num_iters,
        tol=tol,
        covariance_type=cov_type,
        random_state=seed,
    )
    return model.fit_predict(x)
    
def time_model(x, w=None, kind="weighted", number=1, seed=None):
    """Time fitting one of the two mixture models on ``x``.

    Parameters
    ----------
    x : samples to fit.
    w : sample weights, used only when ``kind == "weighted"``.
    kind : ``"weighted"`` (uses ``run_model``) or ``"unweighted"``
        (uses ``run_model_unweighted``).
    number : number of fits to run; the mean time per fit is reported.
    seed : ``random_state`` forwarded to the model.

    Returns
    -------
    dict with ``dt`` (mean seconds per fit), ``kind`` and ``size`` (``len(x)``,
    i.e. the number of bins when ``x`` was binned for the weighted model).

    Raises
    ------
    ValueError
        If ``kind`` is not recognised or ``number`` < 1.
    """
    if kind == "weighted":
        fit = lambda: run_model(x, w, seed)
    elif kind == "unweighted":
        fit = lambda: run_model_unweighted(x, seed)
    else:
        # Previously any other string silently ran the unweighted model while the
        # record was still labelled with the (misspelled) kind.
        raise ValueError(f"kind must be 'weighted' or 'unweighted', got {kind!r}")
    if number < 1:
        # number=0 would time nothing and then divide by zero.
        raise ValueError(f"number must be >= 1, got {number!r}")

    # Note: timeit.timeit turns garbage collection off while timing.
    dt = timeit(fit, number=number)
    return dict(dt=dt / number, kind=kind, size=len(x))
    



Run 20 trials over a range of input sample sizes, binning the weighted model's input with a bin size of 1 (coordinates rounded to integers).

In [None]:
# Benchmark configuration.
N_TRIALS = 20
# 30 log-spaced original sample sizes between 1e3 and 1e4 (truncated to int, as before).
SAMPLE_SIZES = np.logspace(3, 4, 30).astype(int)
# Set DATA_SEED to an int for reproducible data generation and subsampling
# (make_data and DataFrame.sample both use NumPy's global RNG here).
DATA_SEED = None
if DATA_SEED is not None:
    np.random.seed(DATA_SEED)

out = []     # one record per (trial, sample size, model kind)
r = 0        # decimals for binning: rounding to 0 decimals -> unit-sized bins
seed = None  # random_state for the model fits (None -> non-deterministic)


def bin_points(points, decimals):
    """Round the x/y coordinates and collapse duplicates into (coords, counts) arrays."""
    binned = (
        points[["x", "y"]]
        .round(decimals)
        .groupby(["x", "y"])
        .size()
        .to_frame("weight")
        .reset_index()
    )
    return binned[["x", "y"]].to_numpy(), binned["weight"].to_numpy()


for _ in tqdm.trange(N_TRIALS):
    for num_points in tqdm.tqdm(SAMPLE_SIZES, leave=False):
        n = int(num_points)
        # A fresh dataset per run, so cluster placement (and hence bin count) varies.
        x_df = make_data()
        x_sample = x_df.sample(n)

        x = x_sample[["x", "y"]].to_numpy()
        out.append({**time_model(x, w=None, kind="unweighted", seed=seed), "og_size": n})

        x, w = bin_points(x_sample, r)
        out.append({**time_model(x, w=w, kind="weighted", seed=seed), "og_size": n})
        
        

Plot the runtime distribution over trials

As expected, the cost of each iteration is proportional to the number of samples, so binning the input gives a speed-up. More input points and larger bin sizes give bigger gains (at the cost of some accuracy).

In [None]:
out_df = pd.DataFrame(out)

# Mean and standard deviation of the runtime per (model kind, original sample size).
runtime_stats = (
    out_df.groupby(["kind", "og_size"])["dt"]
    .agg(["mean", "std"])
    .reset_index()
    .rename(columns=dict(mean="mean_runtime", std="std_runtime"))
)


def plot_method(kind):
    """Overlay the mean runtime curve, a +/-1 std band and the per-trial timings for `kind`."""
    stats = runtime_stats.query(f"kind=='{kind}'")
    trials = out_df.query(f"kind=='{kind}'")
    band = hv.Spread(stats, ["og_size"], ["mean_runtime", "std_runtime"]).opts(line_color=None, alpha=.3)
    curve = hv.Curve(stats, ["og_size"], ["mean_runtime"]).opts(
        logx=True, logy=False, width=400, height=400, line_width=2,
        xlabel="input samples", ylabel="runtime per fit (s)",
    )
    points = hv.Points(trials, ["og_size", "dt"]).opts(
        width=400, height=400, size=5, alpha=.5, line_color=None, padding=.1,
    )
    return band * curve * points


dt_plot = hv.NdOverlay(
    [(m, plot_method(m)) for m in ["weighted", "unweighted"]], kdims=["method"],
).opts(legend_position="top")

# Number of bins the weighted model was actually fitted on, per original sample size.
bins_plot = hv.Points(
    out_df.query("kind=='weighted'").rename(columns=dict(size="bins")), ["og_size", "bins"],
).opts(color="green", width=400, height=400, size=5, alpha=.5, line_color=None,
       logx=True, logy=False, padding=.1, xlabel="input samples", ylabel="bins")

# NOTE: x_df is whatever dataset the last iteration of the benchmark loop generated.
original_plot = hv.Points(
    x_df.sample(10000), kdims=["x", "y"], vdims=["true_cluster"], label="og",
).opts(color="true_cluster", cmap="Category20", size=5, alpha=.1, width=400, height=400,
       show_legend=False, padding=.2)
binned_plot = hv.Points(
    x_df[["x", "y"]].round(r).groupby(["x", "y"]).size().to_frame("weight").reset_index(),
    kdims=["x", "y"], vdims=["weight"], label="binned",
).opts(color="weight", cmap="fire", logz=True, colorbar=True, size=10, width=470, height=400,
       padding=.2, show_legend=False)

p = dt_plot + bins_plot + original_plot + binned_plot
p.cols(2)