In [1]:
import numpy as np
import pickle as pkl
import plotly.graph_objects as go

In [2]:
def get_beta_params(low, high, middle):
    """
    Determines Beta parameters such that when sampling X ~ Beta(a, b)
    and scaling alpha = low + (high - low) * X, we have P(alpha < 0.5) ≈ middle.
    Here we fix a = 1 and solve for b.

    Parameters:
      low    : lower bound of the target range.
      high   : upper bound of the target range.
      middle : desired probability that alpha is less than 0.5.

    Returns:
      (a, b) : tuple of Beta parameters.
    """
    a = 1.0
    q = (0.5 - low) / (high - low)
    b = np.log(1 - middle) / np.log(1 - q)
    return a, b


def sample_noise_mask(T,
                      p_clean_to_noisy=0.01,
                      p_noisy_to_clean=0.1):
    """
    Returns a boolean mask of length T.  True = noisy state, False = clean.
    The expected run‐length in noisy state is ~1/p_noisy_to_clean,
    and the long‐run fraction of time spent in noisy state is
        π_noisy = p_clean_to_noisy / (p_clean_to_noisy + p_noisy_to_clean)
    """
    state = False  # start clean
    mask = np.zeros(T, dtype=bool)
    for t in range(T):
        if not state:
            # clean → noisy?
            if np.random.rand() < p_clean_to_noisy:
                state = True
        else:
            # noisy → clean?
            if np.random.rand() < p_noisy_to_clean:
                state = False
        mask[t] = state
    return mask


def generate_locked_bursty_trace(target,
                                 low=0.05, high=1.0, middle=0.8,
                                 noise_level_clean=0.0,
                                 noise_level_noisy=0.6,
                                 p_c2n=0.005, p_n2c=0.02):
    """
    Same as before, but in each noisy segment we pick a single wrong class j
    and blend (1-noise_level)*one_hot(j) + noise_level*Dirichlet_noise,
    so argmax stays j throughout the burst.
    """
    T, K = target.shape
    # 1) sample the clean/noisy mask
    noisy_mask = sample_noise_mask(T, p_c2n, p_n2c)

    # 2) precompute your Beta->Dirichlet concentration if you like
    a, b = get_beta_params(low, high, middle)

    noisy_trace = []
    t = 0
    while t < T:
        if not noisy_mask[t]:
            # clean: just keep original p exactly
            p = target[t]
            lvl = noise_level_clean
            # fresh Dirichlet noise around uniform
            alpha_vec = np.ones(K) * np.random.beta(a, b)
            noise = np.random.dirichlet(alpha_vec)
            perturbed = (1 - lvl) * p + lvl * noise
            if np.any(np.isnan(perturbed)):
                    perturbed = p
            noisy_trace.append(perturbed / perturbed.sum())
            t += 1
        else:
            # find the length of this noisy run
            start = t
            while t < T and noisy_mask[t]:
                t += 1
            end = t

            # pick one wrong class j for the _whole_ run
            true_label = target[start].argmax()
            # choose j != true_label
            candidates = list(range(K))
            candidates.remove(true_label)
            j = np.random.choice(candidates)

            # now fill in each time‐step in [start, end)
            for _ in range(start, end):
                # draw fresh Dirichlet around uniform base (or you can weight it)
                alpha_vec = np.ones(K) * np.random.beta(a, b)
                noise = np.random.dirichlet(alpha_vec)

                # mix with constant one_hot(j)
                lvl = noise_level_noisy
                one_hot_j = np.eye(K)[j]
                perturbed = (1 - lvl) * one_hot_j + lvl * noise
                # it already sums to 1, so no need to renormalize
                if np.any(np.isnan(perturbed)):
                    perturbed = one_hot_j
                noisy_trace.append(perturbed)

    return np.stack(noisy_trace)


def visualize_traces(x, y, z, names=None):
    if names is None:
        names = ['Original', 'Argmax', 'Reconstructed']
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=list(range(len(x))), y=x, mode='lines', name=names[0],
                             line=dict(color='blue', dash='solid')))
    fig.add_trace(go.Scatter(x=list(range(len(y))), y=y, mode='lines', name=names[1],
                             line=dict(color='red', dash='dash')))
    fig.add_trace(go.Scatter(x=list(range(len(z))), y=z, mode='lines', name=names[2],
                             line=dict(color='green', dash='dot')))
    return fig

In [30]:
data_path = r"../data/pickles/ava_unified.pkl"
with open(data_path, "rb") as f:
    data = pkl.load(f)

target, source = data['target'], data['stochastic']

In [31]:
import sklearn.metrics as metrics

np.mean([metrics.accuracy_score(np.argmax(t, axis=1), np.argmax(s, axis=1)) for t, s in zip(target, source)])

0.9824006952353793

In [101]:
better_bursty_noisy_trace = generate_locked_bursty_trace(target[0], low=0.05, high=3, middle=0.5,
                                                         noise_level_clean=0.5,
                                                         noise_level_noisy=0.5,
                                                         p_c2n=0.5,
                                                         p_n2c=0.9)

In [102]:
visualize_traces(np.argmax(target[0], axis=1), np.argmax(source[0], axis=1),
                 np.argmax(better_bursty_noisy_trace, axis=1))

In [32]:
better_bursty_dataset = [np.array(generate_locked_bursty_trace(t, low=0.05, high=3, middle=0.5,
                                                         noise_level_clean=0.5,
                                                         noise_level_noisy=0.5,
                                                         p_c2n=0.5,
                                                         p_n2c=0.9)) for t in target]

In [33]:
np.mean([metrics.accuracy_score(np.argmax(t, axis=1), np.argmax(s, axis=1)) for t, s in zip(target, better_bursty_dataset)])

0.6415021457059498

In [27]:
from scipy.stats import entropy

def average_entropy_data(data):
    return np.mean([np.mean(entropy(x, axis=1)) for x in data])

In [34]:
average_entropy_data(source), average_entropy_data(better_bursty_dataset)

(0.8858933885253342, 2.0193925465282625)

In [35]:
with open("../data/pickles/ava_improved.pkl", "wb") as f:
    pkl.dump({'target': target, 'stochastic': better_bursty_dataset}, f)