In [1]:
import pickle as pkl
import torch
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from scipy.stats import entropy

In [3]:
# Location of the pickled dataset opened by the loading cell; the path is
# relative, so it resolves against the notebook's working directory.
data_path = r"../data/pickles/50_salads_unified.pkl"

In [4]:
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
with open(data_path, "rb") as f:
    data = pkl.load(f)

# Pull out the two collections compared in this notebook.
det = data['target']
sto = data['stochastic']

In [5]:
def average_entropy_sequence(seq):
    """
    Mean entropy over the elements of a sequence.

    Each element of `seq` is passed to `scipy.stats.entropy` on its own
    (so each is treated as one distribution), and the resulting entropies
    are averaged with `np.mean`.
    """
    per_step_entropy = list(map(entropy, seq))
    return np.mean(per_step_entropy)

def average_entropy_data(data):
    """
    Mean of the per-sequence average entropies over a collection.

    Every item of `data` is reduced with `average_entropy_sequence`, and
    those per-sequence values are averaged again with `np.mean` (each
    sequence carries equal weight regardless of its length).
    """
    per_sequence = list(map(average_entropy_sequence, data))
    return np.mean(per_sequence)

In [6]:
# Average entropy of each collection, displayed as a (det, sto) tuple.
tuple(average_entropy_data(split) for split in (det, sto))

(0.0, 0.24932393)

In [7]:
def get_beta_params(low, high, middle, threshold=0.5):
    """
    Determines Beta parameters such that when sampling X ~ Beta(a, b)
    and scaling alpha = low + (high - low) * X, we have
    P(alpha < threshold) = middle.

    We fix a = 1. For Beta(1, b) the CDF is F(x) = 1 - (1 - x)**b, so
    solving F(q) = middle with q = (threshold - low) / (high - low) gives
    b = log(1 - middle) / log(1 - q).

    Parameters:
      low       : lower bound of the target range.
      high      : upper bound of the target range.
      middle    : desired probability that alpha is less than `threshold`;
                  must lie strictly between 0 and 1.
      threshold : cut-off inside the target range (default 0.5); must
                  satisfy low < threshold < high.

    Returns:
      (a, b) : tuple of Beta parameters.

    Raises:
      ValueError : if `threshold` is not strictly inside (low, high) or
                   `middle` is not strictly inside (0, 1). Without these
                   checks the formula yields NaN, infinite or negative b.
    """
    # Written as negated chained comparisons so that NaN inputs are rejected too.
    if not low < threshold < high:
        raise ValueError(
            f"threshold must satisfy low < threshold < high; "
            f"got low={low}, threshold={threshold}, high={high}"
        )
    if not 0.0 < middle < 1.0:
        raise ValueError(f"middle must lie strictly in (0, 1); got {middle}")

    a = 1.0
    # Position of the threshold after mapping [low, high] onto [0, 1].
    q = (threshold - low) / (high - low)
    b = np.log(1 - middle) / np.log(1 - q)
    return a, b

def sample_concentration_parameter(low=0.05, high=1, middle=0.8, rng=None):
    """
    Samples a concentration parameter α in [low, high] such that
    P(α < 0.5) ≈ middle (0.8 with the defaults).

    Parameters:
      low    : lower bound of the sampled range.
      high   : upper bound of the sampled range.
      middle : target probability that α falls below 0.5.
      rng    : optional random source exposing `.beta(a, b)`, e.g. a
               `np.random.Generator` from `np.random.default_rng(seed)`.
               When None, the global `np.random` state is used, so
               `np.random.seed(...)` still controls the draw.

    Returns:
      alpha : the sampled concentration parameter.
    """
    a, b = get_beta_params(low, high, middle)
    # Fall back to the legacy global RNG so existing callers behave as before.
    sampler = np.random if rng is None else rng
    x = sampler.beta(a, b)
    # Scale the unit-interval draw to [low, high].
    alpha = low + (high - low) * x
    return alpha

def add_dirichlet_noise(p, alpha=0.1, noise_level=0.1, rng=None):
    """
    Mixes a probability vector with a Dirichlet-distributed noise vector.

    p: original probability vector (e.g., one-hot). Any 1-D array-like
       (list, tuple, ndarray) is accepted; it is not modified.
    alpha: concentration parameter for the Dirichlet distribution.
           Lower alpha makes the Dirichlet sample more "peaky".
    noise_level: mixing coefficient for the noise.
    rng: optional random source exposing `.dirichlet(alpha)`, e.g. a
         `np.random.Generator`. When None, the global `np.random` state
         is used, so `np.random.seed(...)` still controls the draw.

    Returns the perturbed vector as a float ndarray, renormalized to sum to 1.
    """
    # Coerce up front: plain lists/tuples would otherwise fail at `float * list`.
    p = np.asarray(p, dtype=float)
    sampler = np.random if rng is None else rng
    # Generate a Dirichlet noise vector (symmetric, one entry per element of p).
    noise = sampler.dirichlet(np.ones_like(p) * alpha)
    # Mix the original distribution with the noise.
    perturbed = (1 - noise_level) * p + noise_level * noise
    # Ensure it sums to 1.
    return perturbed / perturbed.sum()

In [None]:
for trace in det:
    alphas = [sample_concentration_parameter() for _ in range(trace.shape[0])]