In [None]:
from pathlib import Path
from typing import NamedTuple

from fractopo import Network
from fractopo.general import read_geofile, JOBLIB_CACHE
from joblib import Parallel, delayed
from fractopo.analysis.length_distributions import determine_fit, calculate_exponent
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import ticker
import pandas as pd
import seaborn as sns

In [None]:
# Read the "lengths" column from the trace length CSV of each network
# directory (1_10, 1_20000 and 1_200000_int) and keep only the underlying
# array of values, one array per network.
trace_lengths_1_10 = pd.read_csv("../outputs/networks/1_10/trace_lengths.csv")[
    "lengths"
].values
trace_lengths_1_20k = pd.read_csv("../outputs/networks/1_20000/trace_lengths.csv")[
    "lengths"
].values
trace_lengths_1_200k = pd.read_csv(
    "../outputs/networks/1_200000_int/trace_lengths.csv"
)["lengths"].values

In [None]:
class FitBase(NamedTuple):
    """
    Minimal container for the parts of a length distribution fit used here.

    The fields mirror the ``xmin``, ``data`` and ``alpha`` attributes that are
    read from the result of ``determine_fit`` in this notebook.
    """

    # Cut-off value of the fit; plotted below as the truncation cut-off.
    xmin: float
    # Data array of the fit; compared against xmin to count values above it.
    data: np.ndarray
    # Fit parameter that is passed to calculate_exponent to get the exponent.
    alpha: float


def determine_censoring_cut_off_fit(censoring_cut_off, lengths):
    """
    Fit the lengths that fall strictly below a censoring cut-off.

    Lengths greater than or equal to ``censoring_cut_off`` are dropped and
    ``determine_fit`` is run on the remaining values.

    :param censoring_cut_off: Exclusive upper bound of the lengths to keep.
    :param lengths: Array of lengths that supports boolean-mask indexing.
    :return: Tuple of the ``xmin``, ``data`` and ``alpha`` attributes of the
        resulting fit, in that order.
    """
    is_below_censoring = lengths < censoring_cut_off
    censored_fit = determine_fit(lengths[is_below_censoring])
    return (censored_fit.xmin, censored_fit.data, censored_fit.alpha)


@JOBLIB_CACHE.cache
def resolve_censoring_fits(lengths: tuple, num: int = 20):
    """
    Fit the lengths repeatedly with a decreasing censoring cut-off.

    ``num`` evenly spaced cut-offs are generated, starting slightly above the
    longest length and ending at the shortest length. For every cut-off the
    lengths below it are fitted with ``determine_censoring_cut_off_fit``; the
    fits are dispatched through ``joblib.Parallel`` with ``n_jobs=-1``.

    :param lengths: Lengths to fit, given as a tuple.
    :param num: Number of censoring cut-offs to evaluate.
    :return: Tuple of two tuples: the per-cut-off fit results and the
        censoring cut-offs, both in the same (descending cut-off) order.
    """
    length_array = np.array(lengths)

    # The small offset places the first cut-off above the maximum length so
    # that the strict "less than" filter keeps every length for the first fit.
    highest_cut_off = length_array.max() + 0.001
    lowest_cut_off = length_array.min()
    censoring_cut_offs = np.linspace(
        start=highest_cut_off, stop=lowest_cut_off, num=num
    )

    fit_tasks = [
        delayed(determine_censoring_cut_off_fit)(
            censoring_cut_off=cut_off,
            lengths=length_array,
        )
        for cut_off in censoring_cut_offs
    ]
    fits = Parallel(n_jobs=-1)(fit_tasks)
    assert isinstance(fits, list)

    return tuple(fits), tuple(censoring_cut_offs)

In [None]:
# Resolve fits for 50 censoring cut-offs for each set of trace lengths.
# The length arrays are converted to tuples to match the ``lengths: tuple``
# parameter of resolve_censoring_fits. Each call returns the fit results and
# the censoring cut-offs that produced them.
fits_1_10, censoring_cut_offs_1_10 = resolve_censoring_fits(
    tuple(trace_lengths_1_10), num=50
)
fits_1_20k, censoring_cut_offs_1_20k = resolve_censoring_fits(
    tuple(trace_lengths_1_20k), num=50
)
fits_1_200k, censoring_cut_offs_1_200k = resolve_censoring_fits(
    tuple(trace_lengths_1_200k), num=50
)

In [None]:
def visualize_effect_of_censoring(censoring_cut_offs, fits, suptitle, lengths):
    """
    Plot how the censoring cut-off affects the fit of the lengths.

    Three scatter plots are drawn side by side, all with the censoring
    cut-off on the x-axis:

    1. the power-law exponent calculated from the ``alpha`` of each fit,
    2. the truncation cut-off (``xmin``) of each fit,
    3. the cut-off proportion, i.e. one minus the share of ``lengths`` that
       is above ``xmin`` in the fitted data.

    The y-axis of the exponent plot is limited to ``[y_min, -1.0]``. Exponents
    outside of the limits are indicated with short vertical tick marks at the
    corresponding edge of the axis.

    :param censoring_cut_offs: Censoring cut-offs, one per fit.
    :param fits: Iterable of ``(xmin, data, alpha)`` tuples, one per
        censoring cut-off and in the same order as ``censoring_cut_offs``.
    :param suptitle: Text drawn rotated on the left edge of the figure.
    :param lengths: The full set of lengths; only its size is used, as the
        denominator of the cut-off proportion.
    :return: The created figure and its array of three axes.
    """
    censoring_cut_offs = np.array(censoring_cut_offs)
    fits = [FitBase(xmin=fit[0], data=fit[1], alpha=fit[2]) for fit in fits]

    exponents = np.array([calculate_exponent(fit.alpha) for fit in fits])
    cut_offs = [fit.xmin for fit in fits]

    cut_off_proportions = np.array(
        [
            1 - (np.count_nonzero(fit.data > fit.xmin) / len(lengths))
            for fit in fits
        ]
    )

    # Plotting
    with sns.plotting_context("paper"):
        fig, axes = plt.subplots(1, 3, figsize=(8.23, 2.5))

        ax_1 = axes[0]
        sns.scatterplot(ax=ax_1, x=censoring_cut_offs, y=exponents)
        ax_1.set_title("Censoring cut-off vs.\npower-law exponent")
        ax_1.set_ylabel("Power-law exponent")
        y_max = -1.0
        # NaN exponents are ignored when resolving the lower limit. If every
        # exponent is NaN the floor value of -4.5 is used instead of calling
        # min() on an empty array, which would raise.
        valid_exponents = exponents[~np.isnan(exponents)]
        lowest_exponent = valid_exponents.min() if valid_exponents.size > 0 else -4.5
        y_min = max([-4.5, lowest_exponent]) - 0.5
        ax_1.set_ylim(y_min, y_max)
        # Mark exponents that fall outside the y-limits with short ticks that
        # are drawn just inside the respective axis edge so they stay visible.
        marker_height = 0.1
        ax_1.vlines(
            censoring_cut_offs[exponents < y_min],
            ymin=y_min,
            ymax=y_min + marker_height,
        )
        ax_1.vlines(
            censoring_cut_offs[exponents > y_max],
            ymin=y_max - marker_height,
            ymax=y_max,
        )

        ax_2 = axes[1]
        sns.scatterplot(ax=ax_2, x=censoring_cut_offs, y=cut_offs)
        ax_2.set_title("Censoring cut-off vs.\ntruncation cut-off")
        ax_2.set_ylabel("Truncation cut-off [$m$]")

        ax_3 = axes[2]
        sns.scatterplot(ax=ax_3, x=censoring_cut_offs, y=cut_off_proportions)
        ax_3.set_title("Censoring cut-off vs.\ncut-off proportion")
        ax_3.set_ylabel("Cut-off proportion")

        for ax in axes:
            ax.set_xlabel("Censoring cut-off [$m$]")

        fig.subplots_adjust(wspace=0.5)
        fig.suptitle(suptitle, x=0.04, y=0.5, rotation=90)
    return fig, axes

In [None]:
# Draw one three-panel figure per network scale. Each row bundles the fits,
# the figure title, the censoring cut-offs and the original trace lengths.
plot_inputs = (
    (fits_1_10, "1:10", censoring_cut_offs_1_10, trace_lengths_1_10),
    (fits_1_20k, "1:20k", censoring_cut_offs_1_20k, trace_lengths_1_20k),
    (fits_1_200k, "1:200k", censoring_cut_offs_1_200k, trace_lengths_1_200k),
)
for fits, suptitle, censoring_cut_offs, lengths in plot_inputs:
    fig, axes = visualize_effect_of_censoring(
        censoring_cut_offs=censoring_cut_offs,
        fits=fits,
        suptitle=suptitle,
        lengths=lengths,
    )