In [10]:
# Precompute Gauss–Legendre grids (1-D or 2-D, parallelized for 2-D)
# Builds and saves GL nodes/weights on [-Umax, Umax] (1-D) and optionally
# full 2-D Gauss–Legendre (u1,u2,W) on [-Umax,Umax]^2 for any (n, Umax).
# Uses multiprocessing with row-blocked memmaps for the 2-D case.

from __future__ import annotations
import os, json, hashlib
from math import ceil
from pathlib import Path
from time import perf_counter
from multiprocessing import Pool

import numpy as np
from numpy.polynomial.legendre import leggauss

# Directory that receives every generated grid file and its metadata.
FIFT_GL2D_DIR = Path("/n/netscratch/dvorkin_lab/Lab/nephremidze/2-LISA/0-parallel/fift_gl2d")
# Upper bound on worker processes for the 2-D build (precompute_gl clamps it to n).
NUM_WORKERS    = 112
# The 2-D build aims for roughly this many row-chunks (tasks) per worker.
TARGET_CHUNKS_PER_PROC = 8

# (n, Umax) pairs to build: n quadrature nodes per axis on [-Umax, Umax].
PAIRS = [(8000, 200.0)]

# Which products to build for each pair:
#   1 -> only the 1-D nodes/weights (x, w)
#   2 -> 1-D nodes/weights plus the full 2-D grid (u1, u2, W)
GL_DIM = 1

# NOTE: runs at import time; precompute_gl() also creates its base_dir itself.
FIFT_GL2D_DIR.mkdir(parents=True, exist_ok=True)


def gl2d_paths(base_dir: Path, n: int, Umax: float) -> dict[str, Path]:
    """Return the file locations of every artifact for one (n, Umax) grid.

    All files live in *base_dir* and share the stem ``gl2d_n{n}_U{Umax}``.
    Both *n* and *Umax* are truncated with ``int()`` before formatting, so
    two ``Umax`` values with the same integer part map to the same files.

    Returns:
        A dict with the keys ``"base"`` (the extension-less stem path),
        ``"u1"``, ``"u2"``, ``"W"``, ``"meta"``, ``"x"`` and ``"w"``, in that
        order. Array entries end in ``.<key>.npy``; ``"meta"`` ends in
        ``.meta.json``.
    """
    stem_path = base_dir / "gl2d_n{:d}_U{:d}".format(int(n), int(Umax))

    # (key, suffix) pairs, listed in the order the keys appear in the result.
    suffix_by_key = (
        ("u1", ".u1.npy"),
        ("u2", ".u2.npy"),
        ("W", ".W.npy"),
        ("meta", ".meta.json"),
        ("x", ".x.npy"),
        ("w", ".w.npy"),
    )

    layout = {"base": stem_path}
    layout.update((key, stem_path.with_suffix(sfx)) for key, sfx in suffix_by_key)
    return layout


def sha256_file(path: Path, chunk: int = 1 << 22) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is opened in binary mode and consumed through repeated
    ``read(chunk)`` calls (4 MiB per call by default), so it is never held
    in memory as a whole.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF.
        for piece in iter(lambda: stream.read(chunk), b""):
            digest.update(piece)
    return digest.hexdigest()


def write_meta(paths: dict[str, Path],
               n: int,
               Umax: float,
               dtype: str = "float64",
               dim: int = 2) -> dict:
    """Write the ``.meta.json`` sidecar for one grid and return its content.

    For each of the keys ``x``, ``w``, ``u1``, ``u2``, ``W`` whose path is
    present in *paths* and exists on disk, the file name and its SHA-256
    digest are recorded. Files that do not exist are simply left out.

    The JSON is first written to a temporary sibling file and then moved
    over ``paths["meta"]`` with ``os.replace``. ``precompute_gl`` treats the
    presence of the metadata file as "build finished", so an interrupted
    write must never leave a truncated file under the final name.

    Args:
        paths: Mapping as returned by ``gl2d_paths``; must contain ``"meta"``.
        n: Number of quadrature nodes per axis (stored as ``int``).
        Umax: Half-width of the integration interval (stored as ``float``).
        dtype: Name of the array dtype, recorded verbatim.
        dim: Dimensionality of the build (stored as ``int``).

    Returns:
        The metadata dict that was serialized.
    """
    files: dict[str, str] = {}
    sha256: dict[str, str] = {}
    for key in ("x", "w", "u1", "u2", "W"):
        p = paths.get(key)
        if p is not None and p.exists():
            files[key] = p.name
            sha256[key] = sha256_file(p)

    meta = {
        "version": 2,
        "n": int(n),
        "u_max": float(Umax),
        "dtype": dtype,
        "dim": int(dim),
        "files": files,
        "sha256": sha256,
    }

    meta_path = Path(paths["meta"])
    # The PID in the temp name keeps concurrent writers from sharing a file.
    tmp_path = meta_path.with_name(f"{meta_path.name}.{os.getpid()}.tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)
        # Atomic on the same filesystem: readers see the old or the new file.
        os.replace(tmp_path, meta_path)
    finally:
        # Only still present if the dump or the rename failed.
        if tmp_path.exists():
            tmp_path.unlink()
    return meta


def _fill_rows_worker(args):
    """Fill rows ``i0 <= i < i1`` of the flattened 2-D grid files.

    *args* is a single tuple ``(i0, i1, n, x_path, w_path, u1_path, u2_path,
    W_path)`` so the function can be used directly with ``Pool.imap*``.

    ``x`` and ``w`` are loaded read-only from their ``.npy`` files. The three
    output files are opened as raw float64 memmaps of length ``n*n`` (no
    ``.npy`` header) in ``r+`` mode, so they must already exist with that
    size. Row ``i`` occupies flat indices ``i*n .. (i+1)*n - 1`` and receives
    ``u1 = x[i]``, ``u2 = x[:]`` and ``W = w[i] * w[:]``.
    """
    first, stop, n, x_path, w_path, u1_path, u2_path, W_path = args

    nodes = np.load(x_path, allow_pickle=False, mmap_mode="r")
    weights = np.load(w_path, allow_pickle=False, mmap_mode="r")

    u1_flat, u2_flat, W_flat = (
        np.memmap(target, dtype=np.float64, mode="r+", shape=(n * n,), order="C")
        for target in (u1_path, u2_path, W_path)
    )

    n_rows = stop - first
    if n_rows > 0:
        lo, hi = first * n, stop * n
        # 2-D views onto this task's contiguous slab of each output file.
        u1_block = u1_flat[lo:hi].reshape(n_rows, n)
        u2_block = u2_flat[lo:hi].reshape(n_rows, n)
        W_block = W_flat[lo:hi].reshape(n_rows, n)

        u1_block[...] = nodes[first:stop, None]   # row i is constant x[i]
        u2_block[...] = nodes[None, :]            # every row is the full x
        # Outer product w[i] * w[j], written straight into the mapped file.
        np.multiply(weights[first:stop, None], weights[None, :], out=W_block)

    for mapped in (u1_flat, u2_flat, W_flat):
        mapped.flush()


def _gl_cache_is_current(paths: dict[str, Path],
                         need_keys: list[str],
                         n: int,
                         u_max: float,
                         dim: int) -> bool:
    """Return True if all needed files exist and the metadata fits the request."""
    if not all(paths[k].exists() for k in need_keys):
        return False
    try:
        with open(paths["meta"], "r", encoding="utf-8") as f:
            meta = json.load(f)
    except (OSError, ValueError):
        # Unreadable or truncated metadata: the cache cannot be trusted.
        return False
    if not isinstance(meta, dict):
        return False
    # Fields absent from the metadata are not treated as a mismatch.
    if "n" in meta and meta["n"] != n:
        return False
    if "u_max" in meta and meta["u_max"] != u_max:
        return False
    if "dim" in meta and meta["dim"] < dim:
        return False
    return True


def precompute_gl(base_dir: Path,
                  n: int,
                  Umax: float,
                  dim: int = 2,
                  nprocs: int = NUM_WORKERS,
                  target_chunks_per_proc: int = TARGET_CHUNKS_PER_PROC,
                  overwrite: bool = False) -> dict[str, Path]:
    """
    Build Gauss–Legendre quadrature data on [-Umax, Umax].

    dim=1:
        - Computes and saves 1-D nodes/weights (x,w) only:
              gl2d_n{n}_U{int(Umax)}.x.npy
              gl2d_n{n}_U{int(Umax)}.w.npy

    dim=2:
        - Computes and saves 1-D nodes/weights as above.
        - Builds full 2-D GL grid (u1,u2,W) using multiprocessing:
              gl2d_n{n}_U{int(Umax)}.u1.npy
              gl2d_n{n}_U{int(Umax)}.u2.npy
              gl2d_n{n}_U{int(Umax)}.W.npy
          NOTE: these three files are written through np.memmap, i.e. they
          are raw float64 arrays of length n*n WITHOUT a .npy header despite
          their extension. Read them with np.memmap, not np.load.

    In both cases, a .meta.json file is written describing which files exist.
    It is written last and removed before a rebuild starts, so its presence
    marks a finished build.

    With overwrite=False an existing build is reused only when all needed
    files exist and the metadata is readable and agrees with the request
    (same n, same u_max, dim at least as large). The u_max check matters
    because the file name only carries int(Umax). Anything else is rebuilt.

    Args:
        base_dir: Output directory (created if missing).
        n: Number of nodes per axis; coerced with int(), must be >= 1.
        Umax: Half-width of the interval; must be finite and > 0.
        dim: 1 or 2, see above.
        nprocs: Maximum number of worker processes for the 2-D build.
        target_chunks_per_proc: Desired number of row-chunks per worker.
        overwrite: Rebuild even if a matching build already exists.

    Returns:
        The path dict from gl2d_paths(base_dir, n, Umax).

    Raises:
        ValueError: if dim is not 1 or 2, n < 1, or Umax is not a positive
            finite number.
    """
    dim = int(dim)
    if dim not in (1, 2):
        raise ValueError("dim must be 1 or 2")

    n = int(n)
    if n < 1:
        raise ValueError("n must be a positive integer")
    u_max = float(Umax)
    if not (np.isfinite(u_max) and u_max > 0.0):
        raise ValueError("Umax must be a positive finite number")

    paths = gl2d_paths(base_dir, n, Umax)
    base_dir.mkdir(parents=True, exist_ok=True)

    if dim == 1:
        need_keys = ["x", "w", "meta"]
    else:
        need_keys = ["x", "w", "u1", "u2", "W", "meta"]

    if not overwrite:
        if _gl_cache_is_current(paths, need_keys, n, u_max, dim):
            print(f"[skip] exists: n={n}, Umax={Umax}, dim={dim}")
            return paths
        if paths["meta"].exists():
            print(f"[stale] rebuilding: n={n}, Umax={Umax}, dim={dim}")

    # Drop the completion marker before touching any data file, so that an
    # interrupted rebuild can never be mistaken for a finished one.
    try:
        paths["meta"].unlink()
    except FileNotFoundError:
        pass

    # 1) 1-D Gauss–Legendre on [-1,1], map to [-Umax, Umax]
    xi, wi = leggauss(n)
    x = u_max * xi
    w = u_max * wi
    np.save(paths["x"], x)
    np.save(paths["w"], w)

    if dim == 1:
        # 1-D only: just write metadata for x,w and return.
        write_meta(paths, n, Umax, dtype="float64", dim=1)
        print(f"[done-1D] n={n}, Umax={Umax}  ->  {paths['base']}.x/w.npy")
        return paths

    # ----- 2-D build (dim == 2) -----

    # 2) Create the output files (flattened length n*n). mode="w+" creates or
    #    truncates each file at its full size, already zero-filled, and the
    #    workers overwrite every row, so no explicit fill is needed here.
    for key in ("u1", "u2", "W"):
        mm = np.memmap(paths[key], dtype=np.float64, mode="w+",
                       shape=(n * n,), order="C")
        mm.flush()
        del mm  # close before multiprocessing

    # 3) Partition rows into contiguous [i0, i1) chunks
    nprocs = max(1, min(int(nprocs), n))
    target_tasks = max(nprocs, target_chunks_per_proc * nprocs)
    chunk_rows = max(1, int(ceil(n / target_tasks)))
    tasks = [
        (i0, min(n, i0 + chunk_rows), n, str(paths["x"]), str(paths["w"]),
         str(paths["u1"]), str(paths["u2"]), str(paths["W"]))
        for i0 in range(0, n, chunk_rows)
    ]
    # Rounding chunk_rows up can leave fewer tasks than requested workers.
    nprocs = min(nprocs, len(tasks))

    print(f"[build-2D] n={n}, Umax={Umax} with {nprocs} procs, "
          f"{len(tasks)} tasks, rows/task≈{chunk_rows}")

    # 4) Launch pool; draining the iterator re-raises any worker exception.
    with Pool(processes=nprocs) as pool:
        for _ in pool.imap_unordered(_fill_rows_worker, tasks):
            pass

    # 5) Metadata (now includes x,w,u1,u2,W)
    write_meta(paths, n, Umax, dtype="float64", dim=2)
    print(f"[done-2D] n={n}, Umax={Umax}  ->  {paths['base']}.*")
    return paths


# ============================ RUN BUILDS =====================================

if __name__ == "__main__":
    # Echo the run configuration before doing any work.
    for label, value in (
        ("Output dir:", FIFT_GL2D_DIR),
        ("Workers   :", NUM_WORKERS),
        ("Pairs     :", [(int(n), float(U)) for (n, U) in PAIRS]),
        ("GL dim    :", GL_DIM),
    ):
        print(label, value)

    # Keyword arguments shared by every build in this run.
    build_options = dict(
        dim=GL_DIM,
        nprocs=NUM_WORKERS,
        target_chunks_per_proc=TARGET_CHUNKS_PER_PROC,
        overwrite=False,
    )

    # Build each requested (n, Umax) pair in turn and time the whole batch.
    t0 = perf_counter()
    for (n, U) in PAIRS:
        precompute_gl(FIFT_GL2D_DIR, int(n), float(U), **build_options)
    t1 = perf_counter()
    print(f"Total wall time: {t1 - t0:.3f} s")


Output dir: /n/netscratch/dvorkin_lab/Lab/nephremidze/2-LISA/0-parallel/fift_gl2d
Workers   : 112
Pairs     : [(8000, 200.0)]
GL dim    : 1
[done-1D] n=8000, Umax=200.0  ->  /n/netscratch/dvorkin_lab/Lab/nephremidze/2-LISA/0-parallel/fift_gl2d/gl2d_n8000_U200.x/w.npy
Total wall time: 33.865 s
