In [1]:
from functools import partial
import time

import numpy as np
from scipy.integrate import quad
from scipy.optimize import root_scalar
from scipy import stats
from scipy.stats import norm, halfnorm, truncnorm

def integrand(block_size, x, m):
    """Integrand over the block maximum ``m`` used to build the scaled CDF.

    Args:
        block_size: Number of values sharing one scale.
        x: Point (in scaled units) at which the CDF is being evaluated.
        m: Candidate value of the block's absolute maximum.

    Returns:
        P(z <= m * x) for a standard normal ``z`` truncated to ``[-m, m]``,
        multiplied by the density at ``m`` of the largest of ``block_size``
        half-normal samples.
    """
    # n * F(m)^(n-1) part of the density of the maximum of n half-normal draws.
    order_stat_factor = block_size * (halfnorm.cdf(m) ** (block_size - 1))
    # 2 * norm.pdf(m) is the half-normal density at m.
    max_density = order_stat_factor * 2 * norm.pdf(m)

    # Probability that a value truncated to [-m, m] falls at or below m * x.
    below_scaled_x = truncnorm.cdf(m * x, -m, m)

    return below_scaled_x * max_density

def scaled_norm_cdf(block_size, x):
    """Integrate ``integrand`` over every possible block maximum ``m``.

    Args:
        block_size: Number of values sharing one scale.
        x: Point at which the CDF is evaluated.

    Returns:
        The value of the integral of ``integrand(block_size, x, m)`` for
        ``m`` from 0 to infinity (the error estimate from ``quad`` is dropped).
    """
    def over_max(m):
        return integrand(block_size, x, m)

    quad_output = quad(over_max, 0, np.inf, epsabs=1e-9)
    # quad returns (value, estimated_abs_error, ...); only the value is used.
    return quad_output[0]

def cdf(x, block_size):
    """CDF built from a ``1 / (2 * block_size)`` offset plus a continuous part.

    The continuous part is ``scaled_norm_cdf`` weighted by
    ``(block_size - 1) / block_size``. The result is forced to 0 for
    ``x < -1`` and to 1 for ``x >= 1``.

    Args:
        x: Evaluation point.
        block_size: Number of values sharing one scale.

    Returns:
        The CDF value as produced by ``np.where`` (a NumPy array/scalar).
    """
    # NOTE(review): the constant offset presumably accounts for a point mass
    # at -1 (the block's extreme element) — confirm against the derivation.
    endpoint_mass = 1 / (2 * block_size)
    interior_mass = scaled_norm_cdf(block_size, x) * (block_size - 1) / block_size
    unclamped = endpoint_mass + interior_mass

    # Clamp outside the support; the upper clamp is applied last so it wins.
    return np.where(x >= 1, 1, np.where(x < -1, 0, unclamped))

def inv_cdf(val, block_size):
    """Invert ``cdf`` numerically for a given probability.

    Args:
        val: Target probability.
        block_size: Number of values sharing one scale.

    Returns:
        ``-1`` when ``val`` is at or below ``1 / (2 * block_size)``, ``1`` when
        it is at or above ``1 - 1 / (2 * block_size)``, otherwise the root of
        ``cdf(x, block_size) - val`` found by ``root_scalar`` on ``[-1, 1]``.
    """
    tail_mass = 1 / (2 * block_size)

    # Probabilities inside either tail mass map straight to the end points.
    if val <= tail_mass:
        return -1
    if val >= 1 - tail_mass:
        return 1

    solution = root_scalar(
        lambda x: cdf(x, block_size) - val,
        bracket=[-1, 1],
    )
    return solution.root

In [2]:
def build_code(start, lower_bound, upper_bound, n_steps, bcdf, binv_cdf, lower_bound_is_code_point=True):
    """Grow a list of code points upward from ``start``, one per step.

    Each new point ``c`` is placed so that the midpoint between the current
    point ``b`` and ``c`` sits at the CDF value ``bcdf(b) + prev_mass``, where
    ``prev_mass`` is the CDF mass between the previous midpoint and ``b``.

    Args:
        start: First code point of the result.
        lower_bound: Point below ``start``; only used (for the first midpoint)
            when ``lower_bound_is_code_point`` is true.
        upper_bound: Value at which growth stops; also used as padding.
        n_steps: Number of points to add after ``start``.
        bcdf: Callable returning the CDF at a point.
        binv_cdf: Callable returning the point for a CDF value.
        lower_bound_is_code_point: If true the first midpoint is halfway
            between ``lower_bound`` and ``start``; otherwise the mass below
            ``start`` is measured from CDF value 0.

    Returns:
        ``np.stack(code, -1)`` of the collected points: ``n_steps + 1`` entries.
    """
    code = [start]
    # `a` is only read when computing the very first midpoint below.
    a = lower_bound if lower_bound_is_code_point else None
    b = start
    prev_midpoint_prob = None
    for _ in range(n_steps):
        # Only true on the first iteration; later iterations reuse the value
        # stored in the else-branch further down.
        if prev_midpoint_prob is None:
            prev_midpoint_prob = bcdf((a + b) / 2) if lower_bound_is_code_point else 0
        # CDF mass between the previous midpoint and the current point b.
        prev_mass = bcdf(b) - prev_midpoint_prob

        # Mirror that mass above b to locate the next midpoint in CDF space.
        next_midpoint_prob = bcdf(b) + prev_mass
        if next_midpoint_prob > 1:
            # Target probability is outside the CDF's range: use 1 directly.
            # Note a/b/prev_midpoint_prob are left unchanged on this path.
            c = 1
        else:
            next_midpoint = binv_cdf(next_midpoint_prob)
            prev_midpoint_prob = next_midpoint_prob

            # next_midpoint is halfway between b and c, so c = 2 * mid - b.
            c = next_midpoint * 2 - b

            a = b
            b = c

        if c >= upper_bound:
            # Reached/passed the upper bound: pad with upper_bound so the
            # result still has n_steps + 1 entries, then stop.
            code.extend(upper_bound for _ in range(n_steps + 1 - len(code)))
            break

        code.append(c)

    return np.stack(code, -1)

def interval_code_search(
    lower_bound,
    upper_bound,
    n_steps,
    block_size,
    bounds_are_code_points=True,
):
    """Search for the starting code point and return the code built from it.

    The search has two phases: a bisection that narrows the range of starting
    values according to a feasibility test on the generated code, followed by
    a ``root_scalar`` solve of ``search_fn`` inside the narrowed bracket.

    Args:
        lower_bound: Lower end of the interval.
        upper_bound: Upper end of the interval.
        n_steps: Number of steps passed to ``build_code``.
        block_size: Block size bound into ``cdf`` and ``inv_cdf``.
        bounds_are_code_points: Forwarded to ``build_code`` as
            ``lower_bound_is_code_point``; also adds a 1e-5 margin to the
            lower bracket when true.

    Returns:
        The array produced by ``build_code`` for the starting value found.
    """
    bcdf = partial(cdf, block_size=block_size)
    binv_cdf = partial(inv_cdf, block_size=block_size)

    # build_code with everything fixed except the starting point.
    code_builder = partial(
        build_code,
        lower_bound=lower_bound,
        upper_bound=upper_bound,
        n_steps=n_steps,
        bcdf=bcdf,
        binv_cdf=binv_cdf,
        lower_bound_is_code_point=bounds_are_code_points
    )

    # Keep the start strictly above lower_bound when that bound is itself a
    # code point.
    lower_bracket = lower_bound + (1e-5 if bounds_are_code_points else 0)

    lower_feasible = lower_bracket
    upper_feasible = upper_bound - 1e-5

    # Bisection: midpoints judged infeasible pull the upper end down, all
    # others push the lower end up, until the range is at most 1e-4 wide.
    while upper_feasible - lower_feasible > 1e-4:
        mid = (lower_feasible + upper_feasible) / 2
        code = code_builder(mid)
        # Infeasible if the code is not strictly increasing, if any entry
        # before the last reaches upper_bound, or if the last exceeds it.
        infeasible = np.any(code[1:] - code[:-1] <= 0) or any(code[:-1] >= upper_bound) or code[-1] > upper_bound
        if infeasible:
            upper_feasible = mid
        else:
            lower_feasible = mid

    upper_bracket = lower_feasible
    def search_fn(val):
        # Residual for the root solve: build the code from `val`, take the
        # mass between the last midpoint and the top point, mirror it above
        # the top point, and compare the resulting CDF value with 1/2.
        # NOTE(review): 1/2 presumably corresponds to upper_bound being the
        # median (0 in construct_af4) — confirm before using other bounds.
        code = code_builder(val)
        top = code[-1]
        prev_split = (top + code[-2]) / 2

        top_prob = bcdf(top)
        target_prob = top_prob + (top_prob - bcdf(prev_split))
        return target_prob - 1/2

    # root_scalar needs search_fn to change sign across the bracket.
    opt_a2 = root_scalar(search_fn, bracket=[lower_bracket, upper_bracket]).root
    code = code_builder(opt_a2)
    return code

def construct_af4(block_size):
    """Assemble the 16-entry code for the given block size.

    The entries are, in order: ``-1``, the code found with 5 steps on
    ``[-1, 0]``, ``0``, the negated and reversed code found with 6 steps on
    ``[-1, 0]``, and ``1``.

    Args:
        block_size: Block size forwarded to ``interval_code_search``.

    Returns:
        A float64 NumPy array of shape ``(16,)``.
    """
    below_zero = interval_code_search(-1, 0, 5, block_size)
    # Mirror a 6-step search on the negative side to obtain the positive side.
    above_zero = -interval_code_search(-1, 0, 6, block_size)[::-1]

    pieces = ([-1.], below_zero, [0.], above_zero, [1.])
    code = np.concatenate(pieces).astype(np.float64)
    assert code.shape == (16,)
    return code

In [3]:
# construct_af4(64)

In [4]:
from tqdm.auto import tqdm, trange

# cdf_array = np.asarray([cdf(x, 64) for x in tqdm(np.linspace(-1.01, 1.01, num=1000, endpoint=True))])

In [5]:
import matplotlib.pyplot as plt

# plt.plot(np.linspace(-1.01, 1.01, num=1000, endpoint=True), cdf_array)

In [6]:
def weight(a, b, block_size: int):
    """Return the CDF mass between ``a`` and ``b``: ``cdf(b) - cdf(a)``."""
    upper_cdf = cdf(b, block_size)
    lower_cdf = cdf(a, block_size)
    return upper_cdf - lower_cdf

def expectation(a, b, block_size: int):
    """Conditional mean of the distribution restricted to ``[a, b]``.

    Uses integration by parts: the integral of ``x dF(x)`` over ``[a, b]``
    equals ``b F(b) - a F(a)`` minus the integral of ``F`` over ``[a, b]``;
    the result is divided by the interval's mass ``F(b) - F(a)``.

    Args:
        a: Lower end of the interval.
        b: Upper end of the interval.
        block_size: Block size passed to ``cdf``.

    Returns:
        The conditional mean. No guard is applied when the interval carries
        zero mass (the division is performed as-is).
    """
    def bs_cdf(x):
        return cdf(x, block_size)

    # Every cdf() call runs an adaptive quadrature, so evaluate each end point
    # once and reuse the value in both the numerator and the denominator.
    cdf_b = bs_cdf(b)
    cdf_a = bs_cdf(a)
    cdf_integral = quad(bs_cdf, a, b, epsabs=1e-9)[0]

    return (b * cdf_b - a * cdf_a - cdf_integral) / (cdf_b - cdf_a)

def expectation2(a, b, block_size: int):
    """Conditional second moment of the distribution restricted to ``[a, b]``.

    Uses integration by parts: the integral of ``x**2 dF(x)`` over ``[a, b]``
    equals ``b**2 F(b) - a**2 F(a)`` minus twice the integral of ``x F(x)``;
    the result is divided by the interval's mass ``F(b) - F(a)``.

    Args:
        a: Lower end of the interval.
        b: Upper end of the interval.
        block_size: Block size passed to ``cdf``.

    Returns:
        The conditional second moment. No guard is applied when the interval
        carries zero mass (the division is performed as-is).
    """
    def bs_cdf(x):
        return cdf(x, block_size)

    # Every cdf() call runs an adaptive quadrature, so evaluate each end point
    # once and reuse the value in both the numerator and the denominator.
    cdf_b = bs_cdf(b)
    cdf_a = bs_cdf(a)
    x_cdf_integral = quad(lambda x: x * bs_cdf(x), a, b, epsabs=1e-9)[0]

    return (b**2 * cdf_b - a**2 * cdf_a - 2 * x_cdf_integral) / (cdf_b - cdf_a)

def variance(a, b, block_size: int):
    """Conditional variance on ``[a, b]``: second moment minus squared mean."""
    second_moment = expectation2(a, b, block_size)
    mean = expectation(a, b, block_size)
    return second_moment - mean ** 2

def grid_variance(borders, block_size: int):
    """Sum of per-cell variance times per-cell mass over consecutive borders.

    Args:
        borders: Ordered cell boundaries; each adjacent pair forms one cell.
        block_size: Block size forwarded to ``variance`` and ``weight``.

    Returns:
        The accumulated sum (``0`` when fewer than two borders are given).
    """
    cells = zip(borders[:-1], borders[1:])
    return sum(
        variance(lo, hi, block_size) * weight(lo, hi, block_size)
        for lo, hi in cells
    )


In [7]:
def get_edenn_grid(block_size, quality=10):
    """Iteratively refine a 16-point grid on ``[-1, 1]`` for one block size.

    Starting from evenly spaced centers, each iteration places cell borders
    halfway between neighbouring centers and moves every interior center to
    the conditional mean (``expectation``) of its cell. The two outer centers
    stay pinned at ``-1.0`` and ``+1.0``.

    Args:
        block_size: Block size forwarded to ``expectation`` and
            ``grid_variance``.
        quality: Number of refinement iterations (default 10).

    Returns:
        A dict with ``"centers"`` (the refined centers as a NumPy array) and
        ``"l2"`` (``grid_variance`` over the borders used in the last
        iteration, extended by ``-0.999999999`` and ``+0.999999999``).
    """
    def midpoints(points):
        # Borders sit halfway between each pair of neighbouring centers.
        return [(lo + hi) / 2 for lo, hi in zip(points[:-1], points[1:])]

    centers = np.linspace(-1.0, +1.0, num=16, endpoint=True)
    # Defined up front so that quality=0 still yields a valid result.
    borders = midpoints(centers)
    for _ in trange(quality, desc="Optimizing the grid...", leave=False):
        borders = midpoints(centers)
        # Use the block_size argument here, not a same-valued global.
        inner_centers = [
            expectation(lo, hi, block_size)
            for lo, hi in zip(borders[:-1], borders[1:])
        ]
        centers = np.asarray([-1.0] + inner_centers + [+1.0])
    # The outer borders stop just short of +/-1.
    l2 = grid_variance([-0.999999999] + borders + [+0.999999999], block_size)

    return {"centers": centers, "l2": l2}

# Results of get_edenn_grid, keyed by the block size they were computed for.
grids = {}

# Compute one grid per block size and report its "l2" value as it finishes.
for gs in tqdm([32, 64, 128, 256, 512, 1024, 2048, 4096], desc="Iterating gs..."):
    result = get_edenn_grid(gs)
    
    print(f"{gs=} {result['l2']=}")
    grids[gs] = result


Optimizing the grid...:   0%|          | 0/10 [00:00<?, ?it/s]

gs=32 l2=0.0012927785480661255


Optimizing the grid...:   0%|          | 0/10 [00:00<?, ?it/s]

gs=64 l2=0.001279107737482122


Optimizing the grid...:   0%|          | 0/10 [00:00<?, ?it/s]

gs=128 l2=0.0012519816027263392


Optimizing the grid...:   0%|          | 0/10 [00:00<?, ?it/s]

gs=256 l2=0.001218730742359411


Optimizing the grid...:   0%|          | 0/10 [00:00<?, ?it/s]

gs=512 l2=0.001183332056735692


Optimizing the grid...:   0%|          | 0/10 [00:00<?, ?it/s]

gs=1024 l2=0.0011478979389850379


Optimizing the grid...:   0%|          | 0/10 [00:00<?, ?it/s]

gs=2048 l2=0.0011135029005145383


Optimizing the grid...:   0%|          | 0/10 [00:00<?, ?it/s]

gs=4096 l2=0.0010806469490479672
