In [None]:
import numpy as np
from scipy.optimize import curve_fit
from scipy.stats import binned_statistic
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import seaborn as sns

In [None]:
%load_ext autoreload
%autoreload 2
from latentrees import *

In [None]:
def param_freezer(func, *args, **kwargs):
    """Freeze every argument of ``func`` except the first positional one.

    Parameters
    ----------
    func : callable
        Called as ``func(x, *args, **kwargs)``.
    *args, **kwargs
        Extra arguments that are bound now and replayed on every call.

    Returns
    -------
    callable
        A one-argument function ``wrapper(x)``.
    """
    frozen_positional, frozen_keywords = args, kwargs

    def wrapper(x):
        # The free argument always goes first; the frozen ones follow it.
        return func(x, *frozen_positional, **frozen_keywords)

    return wrapper

In [None]:
runtime = analyses()
# `params` is the x axis of the exponent plot near the end of the notebook, where it is
# plotted against one fitted exponent per model in `runtime`; keep one entry per model.
# (An earlier sweep used params = [0.3, 1, 1.05]; that assignment was dead code because
# it was overwritten on the very next line.)
params = [2]

#for param in params:
    #runtime.append_model(L=50, distribution = param_freezer(lambda node, param: np.clip(rng.integers(node-1-np.sqrt(3)*np.power(abs(node),param), node+1+np.sqrt(3)*np.power(abs(node),param)), -1e15, 1e15), param), name="{:.2f}".format(param))
    #runtime.append_model(L=15, distribution = param_freezer(lambda node, param: np.clip(rng.normal(node, np.power(abs(node),param)), -1e15, 1e15), param), name="{:.2f}".format(param))
    #runtime.append_model(nl=param, L=50, name="negative_binom_{:d}".format(param))
    #runtime.append_model(L=25, distribution = param_freezer(lambda node, param: np.clip(rng.integers(node-1-param*np.sqrt(3)*abs(node), node+1+param*np.sqrt(3)*abs(node)), -1e15, 1e15), param), name="{:.2f}".format(param))


# Single model registered under the name the following cell looks up ("negative_binom_1").
runtime.append_model(L=50, name="negative_binom_{:d}".format(1))

#runtime.append_model(L=50, distribution = lambda node: np.clip(rng.normal(node, abs(node)), -1e15, 1e15), name="gaus_scaling")
print(runtime)
runtime.run()

In [None]:
# Name of the model of interest; the cells below read its layers, L and nl.
moi_index = "negative_binom_1"
if moi_index not in runtime:
    raise ValueError(f"{moi_index} not available")

# Look the model up once and unpack the attributes used downstream.
moi = runtime[moi_index]
layers, L, nl = moi.layers, moi.L, moi.nl

In [None]:
# Rank plot (log-log) of the sorted node values for a handful of layers,
# with an x^-1 guide line for comparison.
fig = plt.figure()

# Aim for about four layers. round(L/4) is 0 for L <= 2, which would make
# range() raise "arg 3 must not be zero", so the step is clamped to 1.
layer_step = max(1, round(L/4))
for l in range(1, L+1, layer_step):
    cnts = layers[l].sorted_nodes
    #cnts = cnts[cnts>0]
    #freqs = np.unique(cnts, return_counts=True)[1]
    #freqs = freqs / np.sum(freqs)
    freqs = cnts
    x = np.linspace(1, len(freqs), len(freqs))
    # Values in decreasing order, normalised by their sum, against rank 1..n
    plt.plot(x, np.sort(freqs)[::-1]/np.sum(freqs), marker="o", ms=20, lw=10, alpha=0.2, label=l)

# Guide line; `x` is the rank axis of the last layer plotted above.
plt.plot(x, x**-1, color="gray", lw=10, ls="--")
plt.legend()
plt.xscale("log")
plt.yscale("log")
plt.xlabel("i")
plt.ylabel("fi")
# pyplot.show() is the documented way to display a figure; Figure.show()
# only works with a GUI backend.
plt.show()

In [None]:
# Rank plots for layer 10 of the model of interest, each with a power-law fit.
layer_of_interest = runtime[moi_index].layers[10]
cnts = layer_of_interest.sorted_nodes
#cnts = np.abs(cnts)
#cnts = cnts[cnts<1e15]

# Multiplicity of every distinct node value, normalised to sum to 1.
# np.unique returns them ordered by node value, and they are NOT re-sorted here.
freqs = np.unique(cnts, return_counts=True)[1]
freqs = freqs / np.sum(freqs)

# Node values normalised by their total, in decreasing order.
cnts = cnts/cnts.sum()
cnts = np.sort(cnts)[::-1]

# 1-based index axes for the two curves
x = np.linspace(1, len(cnts), len(cnts))
xf = np.linspace(1, len(freqs), len(freqs))

fig = go.Figure()

fig.add_trace(go.Scatter(x=x, y=cnts, marker=dict(symbol="0", size=20, color="blue"), mode="markers+lines", line_width=10, name="", showlegend=False))
fig.add_trace(go.Scatter(x=xf, y=freqs, marker=dict(symbol="0", size=20, color="green"), mode="markers+lines", line_width=10, name="", showlegend=False))

# Reference slope
fig.add_trace(go.Scatter(x=x, y=1/x, line_width=10, line_dash="dash",name="x^-1"))


def fit_func(x, C, gamma):
    """Power law C * x**(-gamma) used for both fits below."""
    return C * np.power(x, - gamma)


# Fit window: skips the first 20 ranks and stops at rank 15000 (or the end of the data).
popt, pcov = curve_fit(fit_func, x[20:15000], cnts[20:15000])
fig.add_trace(go.Scatter(x=x, y=fit_func(x, *popt), line_width=10, line_dash="longdash", name=f"C*x^-{round(popt[1],3)}"))

# Same model on the multiplicity curve, leaving out only its first point.
popt, pcov = curve_fit(fit_func, xf[1:], freqs[1:])
fig.add_trace(go.Scatter(x=xf, y=fit_func(xf, *popt), line_width=10, line_dash="longdash", name=f"C*x^-{round(popt[1],3)}"))

#dd = np.diff(np.diff(cnts))
#mask = np.argwhere((dd[1:]*dd[:-1]<0)).ravel()
#fig.add_trace(go.Scatter(x=x[mask],y=cnts[mask],  name=f"flexes", mode="markers"))


fig.update_xaxes(type="log", title_text="rank")
fig.update_yaxes(type="log", exponentformat="e", title_text="leaf count")
# `title_font_size` replaces the deprecated `titlefont_size` spelling and matches
# the other cells of this notebook. Kept as the last expression so the figure renders.
fig.update_layout(title_text=moi_index, title_font_size=20)

In [None]:
# Rank-frequency plot for the last layer of the model of interest, with a power-law fit.
layer_of_interest = runtime[moi_index].layers[-1]
cnts = layer_of_interest.sorted_nodes
# Multiplicity of every distinct node value, normalised to sum to 1 and ranked in decreasing order.
freqs = np.unique(cnts, return_counts=True)[1]
freqs = freqs/freqs.sum()
freqs = np.sort(freqs)[::-1]

# Ranks 1..n
x = np.linspace(1, len(freqs), len(freqs))

fig = go.Figure()

fig.add_trace(go.Scatter(x=x, y=freqs, marker=dict(symbol="0", size=20, color="blue"), line_width=10, name="", showlegend=False))
# Reference slope
fig.add_trace(go.Scatter(x=x, y=1/x, line_width=10, line_dash="dash",name="x^-1"))


def fit_func(x, C, gamma):
    """Power law C * x**(-gamma) fitted to the rank-frequency curve."""
    return C * np.power(x, - gamma)


# Fit window: skips the first 20 ranks and stops at rank 15000 (or the end of the data).
popt, pcov = curve_fit(fit_func, x[20:15000], freqs[20:15000])
fig.add_trace(go.Scatter(x=x, y=fit_func(x, *popt), line_width=10, line_dash="longdash", name=f"C*x^-{round(popt[1],3)}"))

#dd = np.diff(np.diff(cnts))
#mask = np.argwhere((dd[1:]*dd[:-1]<0)).ravel()
#fig.add_trace(go.Scatter(x=x[mask],y=cnts[mask],  name=f"flexes", mode="markers"))


# `title_font_size` replaces the deprecated `titlefont_size` spelling and matches
# the other cells of this notebook.
fig.update_xaxes(type="log", title_text="rank", title_font_size=30, tickfont_size=25)
fig.update_yaxes(type="log", exponentformat="e", title_text="f", title_font_size=30, tickfont_size=25)
# Kept as the last expression so the figure renders as the cell output.
fig.update_layout(title_text=moi_index, title_font_size=20)

# Last Layer

In [None]:
# Histogram of the last-layer node values of the model of interest.
leaves = np.array(runtime[moi_index].layers[-1].nodes)
# Keep only values strictly inside (-1e15, 1e15)
leaves = leaves[np.abs(leaves) < 1e15]

fig = go.Figure()
fig.add_trace(go.Histogram(x=leaves, nbinsx=100))

layout = {
    "xaxis": {"title": "leaves", "title_font_size": 35, "tickfont_size": 25},
    "yaxis": {"tickfont_size": 25},
}

# Last expression of the cell: returns the figure so it is rendered.
fig.update_layout(layout)

## Histogram of distances

In [None]:
import multiprocessing as mp
import gc
def append_error(err) -> None:
    """Report a failure by printing it.

    Passed as ``error_callback`` to ``Pool.map_async`` in the distance cells,
    so ``err`` is whatever the pool hands to that callback.
    """
    print(err)
        
def append_dist(d):
    """Store one result in the module-level ``distances`` list.

    Passed as ``callback`` to ``Pool.map_async`` in the distance cells.
    ``distances`` is only mutated, never rebound, so it is resolved as a
    global without needing a ``global`` declaration.
    """
    distances.append(d)
    
def measure_func(leaf_A):
    """Absolute differences between one leaf and every later leaf.

    Parameters
    ----------
    leaf_A : (index, value) pair
        One item of ``enumerate(leaves)``.

    Returns
    -------
    list
        One entry per element of the module-level ``leaves``:
        ``abs(value_A - value_B)`` where ``index_A < index_B``, ``nan``
        otherwise, so that each unordered pair is measured only once.
    """
    index_a, value_a = leaf_A[0], leaf_A[1]
    return [
        abs(value_a - value_b) if index_a < index_b else np.nan
        for index_b, value_b in enumerate(leaves)
    ]

In [None]:
# For every model, collect the pairwise absolute differences between (a sample of)
# the nodes of its last layer. Results are stored in `data`, keyed by model name.
data = dict()
for model in runtime:
    loi = model.layers[-1]
    # Cap on the number of nodes compared; the number of pairs grows quadratically.
    N = 500
    if len(loi)>N:
        # Random subsample without replacement (unseeded, so it differs between runs).
        leaves = np.random.choice(loi.nodes,size=N,replace=False)
    else:
        leaves = loi.nodes
    # NOTE: `leaves` must be a module-level name and must be assigned before the
    # pool is created, because measure_func reads it as a global.
    # NOTE(review): this presumably relies on a fork-based start method so that
    # the workers inherit `leaves` - confirm on the target platform.
    # `norm_leaves` is not used by any active line in this cell.
    norm_leaves = max(loi.nodes)
    #print(norm_leaves)

    # append_dist pushes the pool's result into this module-level list.
    distances = []
    pool = mp.Pool(4) 
    # `res` is not used afterwards: results arrive through the callback, and
    # failures are printed by append_error.
    res = pool.map_async(measure_func, enumerate(leaves), callback=append_dist, error_callback=append_error)
    pool.close()
    pool.join()
    # Flatten whatever the callback stored into a 1-D array.
    distances = np.ravel(distances)
    #distances = np.ravel(list(map(lambda leaf: abs((leaf-avg_leaves)/norm_leave),enumerate(leaves))))
    #distances=distances/max([np.nanmax(distances),abs(np.nanmin(distances))])
    # Drop the nan placeholders that measure_func emits for pairs it skips.
    distances = distances[~np.isnan(distances)]
    #distances = distances[distances>=0]

    data[model.name]=distances
    # Release the reference to the layer before forcing a collection.
    loi = None
    gc.collect()

### Distance vs param

In [None]:
# Empirical pdf of the pairwise distances, one curve per entry of `data`.
# With scale_distances the distances are rescaled and binned linearly;
# otherwise they are binned logarithmically and shown on a log x axis.
scale_distances = False

fig = go.Figure()

n_leaves = len(leaves)  # not used below in this cell
for param,distances in data.items():
    try:
        if scale_distances:
            distances=distances/max([np.quantile(distances, 0.99),abs(np.nanmin(distances))])
            bins=np.linspace(0,np.quantile(distances, 0.99),15)
        else:
            # Lower edge ignores (near-)zero distances, which have no finite log10.
            bins=np.logspace(np.log10(distances[distances>1e-10].min()),np.log10(distances.max()), 10)
        density, edges = np.histogram(distances, bins=bins, density=True)
        # Bin centres: one x value per density value. The original computed this
        # into a misspelled, unused variable using edges[:1] and then plotted
        # against `edges`, which has one element more than the densities.
        centers = (edges[1:]+edges[:-1])/2
        fig.add_trace(go.Scatter(x=centers,y=density,  marker=dict(size=20), line=dict(width=10), name=param))
    except Exception as err:
        # Still best effort (e.g. an empty distance array), but say which curve
        # was dropped instead of hiding the failure.
        print(f"skipping {param}: {err!r}")

# `title_font_size` replaces the deprecated `titlefont_size` spelling.
fig.update_layout(xaxis=dict(title="distances", title_font_size=35, tickfont_size=35, nticks= 5),
                 yaxis=dict(title="pdf", title_font_size=35,tickfont_size=35, type="log", exponentformat="e", showexponent='all', nticks=4),
                 legend=dict(x=1.01,y=1,borderwidth=0.5,font_size=15,orientation="v"))

if not scale_distances:
    fig.update_xaxes(type="log")
fig.show()
filename = "images/pdf_distances_nbinom_scaling"
if scale_distances:
    filename+="_scaled"
#fig.write_image(f"{filename}.pdf")
#fig.write_html(f"{filename}.html")

### Distance vs layer

In [None]:
# Empirical pdf of the pairwise distances for every 10th layer of the last model.
fig = go.Figure()

for loi in runtime[-1].layers[::10]:
    # Cap on the number of nodes compared; the number of pairs grows quadratically.
    N = 500
    if len(loi)>N:
        leaves = np.random.choice(loi.nodes,size=N,replace=False)
    else:
        leaves = loi.nodes
    avg_leaves = loi.median  # not used below in this cell

    # `leaves` has to be assigned before the pool is created: measure_func reads
    # it as a global. append_dist pushes the result into `distances`.
    distances = []
    pool = mp.Pool(2)
    res = pool.map_async(measure_func, enumerate(leaves), callback=append_dist, error_callback=append_error)
    pool.close()
    pool.join()
    distances = np.ravel(distances)
    # Drop the nan placeholders that measure_func emits for pairs it skips.
    distances = distances[~np.isnan(distances)]

    n_leaves = len(leaves)  # not used below in this cell

    # Log-spaced bins from the smallest strictly positive distance to the largest.
    bins=np.logspace(np.log10(distances[distances>0].min()),np.log10(distances.max()), 15)
    #bins=np.linspace(distances.min(),distances.max(),20)
    density, edges = np.histogram(distances, bins=bins, density=True)
    # Bin centres: one x value per density value. The original computed this
    # into a misspelled, unused variable using edges[:1] and then plotted
    # against `edges`, which has one element more than the densities.
    centers = (edges[1:]+edges[:-1])/2
    fig.add_trace(go.Scatter(x=centers,y=density,  marker=dict(size=20), line=dict(width=10), name=repr(loi).split(",")[0]))

    gc.collect()

# `title_font_size` replaces the deprecated `titlefont_size` spelling.
fig.update_layout(xaxis=dict(title="distances", title_font_size=35, tickfont_size=35, exponentformat="e", type="log", nticks= 4),
                 yaxis=dict(title="pdf", title_font_size=35,tickfont_size=35, type="log", exponentformat="e", showexponent='all', nticks=4),
                 legend=dict(x=1.01,y=1,borderwidth=0.5,font_size=15,orientation="v"))

fig.show()
filename = "images/distance_pdf_layers_nbinom"
fig.write_image(f"{filename}.pdf")
fig.write_html(f"{filename}.html")

# Hyperparameters

## gamma

In [None]:
def get_exp(layer, x_limits = (0,-1))->float:
    """Fit ``C * rank**(-gamma)`` to a layer's rank-frequency curve and return ``gamma``.

    The curve is built from ``layer.sorted_nodes``: the multiplicity of every
    distinct value, normalised to sum to one and sorted in decreasing order,
    against ranks 1..n.

    Parameters
    ----------
    layer : object
        Anything exposing a ``sorted_nodes`` attribute accepted by ``np.unique``.
    x_limits : (start, stop)
        Slice bounds applied to the ranked curve before fitting. The default
        ``(0, -1)`` fits everything except the last rank.

    Returns
    -------
    float
        The fitted exponent, or ``nan`` when the fit cannot be carried out
        (for instance too few points in the window or no convergence).
    """
    try:
        cnts = layer.sorted_nodes
        freqs = np.unique(cnts, return_counts=True)[1]
        freqs = freqs/freqs.sum()
        freqs = np.sort(freqs)[::-1]
        x = np.linspace(1, len(freqs), len(freqs))

        start, stop = x_limits[0], x_limits[1]
        popt, _ = curve_fit(lambda x, C, gamma: C * np.power(x, - gamma), x[start:stop], freqs[start:stop])
        return popt[1]
    except Exception:
        # Best effort: a failed fit yields nan. Catching Exception rather than
        # using a bare except lets KeyboardInterrupt / SystemExit through, so a
        # long sweep can still be interrupted.
        return np.nan

In [None]:
# One fitted exponent per model, from its last layer (default fit window).
exps = [get_exp(m.layers[-1]) for m in runtime]

In [None]:
# Same fit restricted to three consecutive rank windows, one list per window
# and one exponent per model.
exps_first, exps_second, exps_third = (
    [get_exp(m.layers[-1], x_limits=window) for m in runtime]
    for window in ((0,100), (100,1000), (1000,5000))
)

In [None]:
# Fitted exponent against the model parameter: full-range fit, the three
# rank-window fits, and a dashed reference line at gamma = 1.
x, xlabel = params, "scaling"
#x, xlabel = np.linspace(1,len(exps),len(exps)), "Layer"

fig = go.Figure()

#fig.add_scatter(x = x, y=exps, error_y=dict(type="data", array=exps_errors, visible=True, width=8, thickness=3), name="exponents", mode="lines", marker=dict(size=10), line=dict(width=10, color="gray"))
fig.add_trace(go.Scatter(
    x=x,
    y=exps,
    name="exponents",
    mode="lines",
    marker={"size": 10},
    line={"width": 10, "color": "gray"},
))
# Horizontal guide at 1, extended 10% beyond both ends of the x range
fig.add_trace(go.Scatter(
    x=[min(x)*0.9, max(x)*1.1],
    y=[1, 1],
    name="1",
    mode="lines",
    line={"width": 10, "color": "blue", "dash": "dash"},
))

for exp, name in ((exps_first, "first"), (exps_second, "second"), (exps_third, "third")):
    fig.add_trace(go.Scatter(
        x=x,
        y=exp,
        name=name,
        mode="lines",
        marker={"size": 10},
        line={"width": 10},
    ))

fig.update_traces(marker_size=20)
fig.update_layout(
    xaxis=dict(title=xlabel, exponentformat='e', tickfont=dict(size=20), title_font_size=35),
    yaxis_title="gamma",
    yaxis=dict(tickfont=dict(size=20), title_font=dict(size=35)),
    legend=dict(font_size=30, orientation="v", x=0.9, y=1),
)
fig.show()
filename = "images/exp_scaling_unif_regimes"
#fig.write_image("{}.pdf".format(filename))
#fig.write_html("{}.html".format(filename))

In [None]:
# `gc` is already imported in the distance-helpers cell above; repeating the import is harmless.
import gc
# Force a full collection pass; as the last expression of the cell, its return value is displayed.
gc.collect()