In [1]:
import circle_bundles as cb

from dreimac import CircularCoords    

# For visualizations: register circle_bundles' visualization helpers.
# NOTE(review): presumably this patches plotting methods onto library objects
# (e.g. those used by `summarize(plot=True)` / `show=True` below) — confirm in the
# circle_bundles docs.
cb.attach_bundle_viz_methods()


In [None]:
from __future__ import annotations

from typing import Dict, Optional, Tuple

import numpy as np

from circle_bundles import TriangulationStarCover
from circle_bundles.combinatorics import canon_tri
from circle_bundles.triangle_covers import _require_gudhi

# One full turn in radians; the samplers below draw angles from [0, TWOPI).
TWOPI = 2.0 * np.pi


def _wrap_angle(x: np.ndarray) -> np.ndarray:
    twopi = 2.0 * np.pi
    return np.mod(x, twopi)


def _t2_grid_vertex_coords(n: int) -> Dict[int, np.ndarray]:
    """Map grid vertex id ``i*n + j`` to its angle pair ``(2pi*i/n, 2pi*j/n)``."""
    thetas = (np.arange(n) / n) * TWOPI
    return {
        int(i * n + j): np.array([thetas[i], thetas[j]], dtype=float)
        for i in range(n)
        for j in range(n)
    }


def _t2_grid_faces(n: int, diag: str) -> list[tuple[int, int, int]]:
    """
    Split every square of the periodic n x n grid into two triangles.

    The faces of square (i, j) are stored at indices ``2*(i*n + j)`` (T0) and
    ``2*(i*n + j) + 1`` (T1). Each face lists its vertices as (A, B, C) in the
    order that ``_t2_locate_samples`` uses for its barycentric weights.
    """
    faces: list[tuple[int, int, int]] = []
    for i in range(n):
        ip = (i + 1) % n
        for j in range(n):
            jp = (j + 1) % n
            v00, v10 = i * n + j, ip * n + j
            v01, v11 = i * n + jp, ip * n + jp
            if diag == "slash":
                # diagonal v01--v10:  T0 = (0,0),(1,0),(0,1);  T1 = (1,1),(0,1),(1,0)
                faces.append((v00, v10, v01))
                faces.append((v11, v01, v10))
            else:
                # diagonal v00--v11:  T0 = (0,0),(1,0),(1,1);  T1 = (0,0),(1,1),(0,1)
                faces.append((v00, v10, v11))
                faces.append((v00, v11, v01))
    return faces


def _t2_locate_samples(
    P: np.ndarray, n: int, diag: str, eps: float
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Locate each wrapped sample in the grid triangulation.

    Returns
    -------
    face_idx : (N,) int
        Index into the ``_t2_grid_faces`` list of the triangle containing each sample.
    bary : (N,3) float
        Barycentric weights in that face's (A, B, C) vertex order, clamped to be
        non-negative and renormalized to sum to 1.
    """
    h = TWOPI / n
    scaled = P / h
    cell = np.floor(scaled)
    # Local coords inside the unit square. Measuring them from the *unreduced* cell
    # index keeps them in [0,1) even when a sample sits at (or rounds to) 2pi and its
    # cell index wraps around to 0 below.
    frac = scaled - cell
    u, v = frac[:, 0], frac[:, 1]
    ci = cell[:, 0].astype(int) % n
    cj = cell[:, 1].astype(int) % n

    if diag == "slash":
        # T0: u + v <= 1, weights (1-u-v, u, v); T1: weights (u+v-1, 1-u, 1-v)
        in_T0 = u + v <= 1.0 + eps
        w0 = np.column_stack([1.0 - u - v, u, v])
        w1 = np.column_stack([u + v - 1.0, 1.0 - u, 1.0 - v])
    else:
        # T0: v <= u, weights (1-u, u-v, v); T1: weights (1-v, u, v-u)
        in_T0 = v <= u + eps
        w0 = np.column_stack([1.0 - u, u - v, v])
        w1 = np.column_stack([1.0 - v, u, v - u])

    bary = np.maximum(np.where(in_T0[:, None], w0, w1), 0.0)  # clamp tiny negatives
    total = bary.sum(axis=1, keepdims=True)
    bary = bary / np.where(total <= 0, 1.0, total)  # renormalize (numerical safety)

    face_idx = 2 * (ci * n + cj) + (~in_T0).astype(int)
    return face_idx, bary


def make_t2_periodic_star_cover(
    base_points: np.ndarray,   # (N,2) angles in [0,2pi)^2 (or any real; we wrap)
    *,
    n_per_side: int = 16,      # n x n vertices
    diagonal: str = "slash",   # "slash" or "backslash"
    eps: float = 1e-12,
) -> TriangulationStarCover:
    """
    Periodic triangulation star cover on T^2.

    - Vertices are an n x n grid on [0,2pi)^2 with wrap-around.
    - Each grid square is split into two triangles.
    - Each sample lies in exactly one triangle => belongs to exactly 3 star sets.
    - Overlap order <= 3, so no 3-simplices.

    Parameters
    ----------
    base_points:
        (N,2) array of angles; values are wrapped into [0,2pi).
    n_per_side:
        Grid resolution; must be >= 3 so the grid faces form a simplicial complex.
    diagonal:
        "slash" or "backslash": which diagonal splits each grid square.
    eps:
        Tolerance used to assign samples lying on the diagonal to the first triangle.

    Returns a TriangulationStarCover with:
      - base_points: wrapped samples in [0,2pi)^2
      - K_preimages: same as base_points (planar coords)
      - K: gudhi SimplexTree containing the 2D triangulation
      - vertex_coords_dict: maps old vid -> (theta1,theta2)

    Raises
    ------
    ValueError
        If base_points is not (N,2), n_per_side < 3, or diagonal is unknown.
    """
    P = np.asarray(base_points, dtype=float)
    if P.ndim != 2 or P.shape[1] != 2:
        raise ValueError(f"base_points must be (N,2). Got {P.shape}.")

    n = int(n_per_side)
    if n < 3:
        # With n = 2 the wrap-around makes different grid faces share the same vertex
        # triple, so the faces do not form a simplicial complex.
        raise ValueError("n_per_side must be >= 3.")

    diag = str(diagonal).lower().strip()
    if diag not in {"slash", "backslash"}:
        raise ValueError("diagonal must be 'slash' or 'backslash'.")

    P = _wrap_angle(P)

    # --- Grid vertices (old vid(i,j) = i*n + j) and triangles ---
    vertex_coords_dict = _t2_grid_vertex_coords(n)
    faces = _t2_grid_faces(n, diag)

    # --- Gudhi simplex tree ---
    gd = _require_gudhi()
    K = gd.SimplexTree()
    for (a, b, c) in faces:
        K.insert([int(a), int(b), int(c)])

    # Old vids are already 0..n^2-1, so relabeling is the identity and triangle k of
    # `tris` is face k (all faces are distinct vertex triples once n >= 3).
    tris = [canon_tri(int(a), int(b), int(c)) for (a, b, c) in faces]

    sample_tri, face_bary = _t2_locate_samples(P, n, diag, eps)

    # The weights come out in face (A, B, C) order, but the cover stores canon_tri
    # triples; permute so column p is the weight of vertex tris[t][p]. This is the
    # identity if canon_tri keeps its input order.
    perm = np.array(
        [[face.index(int(vid)) for vid in tri] for face, tri in zip(faces, tris)],
        dtype=int,
    )
    sample_bary = np.take_along_axis(face_bary, perm[sample_tri], axis=1)

    # Build cover object
    cover = TriangulationStarCover(
        base_points=P,
        K_preimages=P,                # we are already in the flat fundamental domain
        K=K,
        vertex_coords_dict=vertex_coords_dict,
    )

    # Fast-path fill: supply triangles and barycentric data directly.
    cover._relabel_vertices()
    cover.triangles = tris
    cover.sample_tri = sample_tri
    cover.sample_bary = sample_bary
    cover._build_star_sets_U()
    cover._build_pou_from_barycentric()
    return cover



# ---------------------------------------------------------------------
# Sampling + projection helpers
# ---------------------------------------------------------------------

def _get_rng(rng: Optional[np.random.Generator]) -> np.random.Generator:
    return np.random.default_rng() if rng is None else rng


def pi_T2xS1_to_T2(total_points: np.ndarray) -> np.ndarray:
    """
    Projection (theta1, theta2, phi) -> (theta1, theta2).

    Parameters
    ----------
    total_points:
        (n,3) array of angles (theta1, theta2, phi) in radians.

    Returns
    -------
    (n,2) array of base angles wrapped into [0, 2pi).

    Raises
    ------
    ValueError
        If total_points is not a 2-D array with 3 columns.
    """
    X = np.asarray(total_points, dtype=float)
    if X.ndim != 2 or X.shape[1] != 3:
        raise ValueError(f"total_points must be (n,3). Got {X.shape}.")
    # Wrap with this cell's own helper so the projection does not depend on names
    # supplied by other cells.
    return _wrap_angle(X[:, :2])


def sample_T2(
    n: int,
    *,
    rng: Optional[np.random.Generator] = None,
    jitter: float = 0.0,
) -> np.ndarray:
    """
    Sample base points on T^2 as angles in [0,2pi)^2.

    Points are drawn uniformly on [0, 2pi)^2, optionally perturbed, then wrapped.

    Parameters
    ----------
    n:
        Number of samples (must be positive).
    rng:
        Random generator; a new unseeded one is created when None.
    jitter:
        Optional Gaussian noise (std in radians) added before wrapping.
        Useful if you want "noisy angles" but still living on the torus.

    Returns
    -------
    base_points : (n,2) angles

    Raises
    ------
    ValueError
        If n is not positive.
    """
    if n <= 0:
        raise ValueError(f"n must be positive. Got {n}.")
    rng = _get_rng(rng)
    X = rng.uniform(0.0, TWOPI, size=(int(n), 2))
    if float(jitter) > 0:
        X = X + rng.normal(0.0, float(jitter), size=X.shape)
    # Wrap with this cell's own helper so the sampler does not depend on names
    # supplied by other cells.
    return _wrap_angle(X)


def sample_T2xS1_trivial(
    n: int,
    *,
    rng: Optional[np.random.Generator] = None,
    base_points: Optional[np.ndarray] = None,
    fiber_jitter: float = 0.0,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Sample points in the trivial total space T^2 x S^1, returned as angles (theta1,theta2,phi).

    If base_points is provided, uses those base samples and only samples fiber angles.
    Fiber angles are uniform on [0, 2pi), optionally perturbed by Gaussian noise of
    std ``fiber_jitter`` (radians), and everything is wrapped into [0, 2pi).

    Returns
    -------
    total_points : (n,3) angles
    base_points  : (n,2) angles

    Raises
    ------
    ValueError
        If base_points is given but is not (n,2), or if sampling fresh base points
        with a non-positive n.
    """
    rng = _get_rng(rng)
    if base_points is None:
        B = sample_T2(n, rng=rng)
    else:
        B = _wrap_angle(np.asarray(base_points, dtype=float))
        if B.ndim != 2 or B.shape[1] != 2:
            raise ValueError(f"base_points must be (n,2). Got {B.shape}.")
        if B.shape[0] != n:
            raise ValueError(f"base_points has n={B.shape[0]} but requested n={n}.")

    phi = rng.uniform(0.0, TWOPI, size=(int(n),))
    if float(fiber_jitter) > 0:
        phi = phi + rng.normal(0.0, float(fiber_jitter), size=phi.shape)

    X = np.column_stack([B[:, 0], B[:, 1], phi])
    # Wrap with this cell's own helper so the sampler does not depend on names
    # supplied by other cells.
    return _wrap_angle(X), B


def sample_T2xS1_bundle_oriented(
    n: int,
    *,
    k: int = 0,
    rng: Optional[np.random.Generator] = None,
    base_points: Optional[np.ndarray] = None,
    fiber_mean: float = 0.0,
    fiber_sigma: float = 1.0,
    jitter_base: float = 0.0,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Convenience sampler for an oriented circle bundle over T^2 with Euler parameter k,
    in the same spirit as our 'twist_by_other_base' convention.

    This is NOT "the" canonical geometric model (there isn't a single one in these coords),
    but it's very useful for experiments because it couples fiber to base in a controlled way:

        phi = fiber_mean + Normal(0, fiber_sigma) + k * theta1 * theta2 / (2pi)

    then wrapped mod 2pi.

    If k=0, this reduces (up to noise) to the trivial product sampler.

    Parameters
    ----------
    k:
        Intended Euler class parameter used in the coupling (truncated to int).
    fiber_mean, fiber_sigma:
        Noise model for fiber angle.
    jitter_base:
        Optional Gaussian jitter (radians) applied to base before wrapping.

    Returns
    -------
    total_points : (n,3) angles
    base_points  : (n,2) angles

    Raises
    ------
    ValueError
        If base_points is given but is not (n,2), or if sampling fresh base points
        with a non-positive n.
    """
    rng = _get_rng(rng)
    if base_points is None:
        B = sample_T2(n, rng=rng, jitter=float(jitter_base))
    else:
        B = _wrap_angle(np.asarray(base_points, dtype=float))
        if B.ndim != 2 or B.shape[1] != 2:
            raise ValueError(f"base_points must be (n,2). Got {B.shape}.")
        if B.shape[0] != n:
            raise ValueError(f"base_points has n={B.shape[0]} but requested n={n}.")
        if float(jitter_base) > 0:
            B = _wrap_angle(B + rng.normal(0.0, float(jitter_base), size=B.shape))

    # base coords
    th1 = B[:, 0]
    th2 = B[:, 1]

    # fiber noise + bilinear coupling (scaled to radians)
    noise = rng.normal(float(fiber_mean), float(fiber_sigma), size=(int(n),))
    coupling = float(int(k)) * (th1 * th2) / TWOPI
    phi = noise + coupling

    X = np.column_stack([th1, th2, phi])
    # Wrap with this cell's own helper so the sampler does not depend on names
    # supplied by other cells.
    return _wrap_angle(X), B


In [None]:
from circle_bundles.t2_bundle_metrics import *
# Seeded draw from the trivial bundle T^2 x S^1 so Restart & Run All reproduces
# the cover summary below.
SEED = 0
data, base_points = sample_T2xS1_trivial(10000, rng=np.random.default_rng(SEED))

cover = make_t2_periodic_star_cover(base_points, n_per_side=9, diagonal="slash")
summ = cover.summarize(plot=True)


In [None]:
# Independent, seeded draw from the trivial bundle T^2 x S^1 and its star cover,
# so Restart & Run All reproduces the bundle results below.
BUNDLE_SEED = 1
data, base_points = sample_T2xS1_trivial(10000, rng=np.random.default_rng(BUNDLE_SEED))

cover = make_t2_periodic_star_cover(base_points, n_per_side=9, diagonal="slash")
summ = cover.summarize(plot=True)

# NOTE(review): `data` comes from the *trivial* sampler, but this metric is the
# oriented-bundle metric with k=1 — confirm the mismatch is intentional.
total_metric = T2_circle_bundle_metric_oriented(k=1, window=2, fiber_weight=1.0)

# Compute local coordinates on the dataset using `total_metric`.
bundle = cb.build_bundle(
    data,
    cover,
    # CircularCoords_cls=CircularCoords,  # optionally use sparse circular coords
    show=True,
    total_metric=total_metric,
)


In [None]:
# Presumably compares the local trivializations of `bundle` (built in the previous cell).
# NOTE(review): the earlier note here read "Construct local circular coordinates and model
# transitions as O(2) matrices"; that may describe build_bundle rather than compare_trivs —
# confirm against the circle_bundles docs.
bundle.compare_trivs()
