In [6]:
def calc_tile_histogram(tile, hist_size, clip_limit):
    """Compute histogram for one tile, apply clipping & redistribution."""
    hist, _ = np.histogram(tile, bins=hist_size, range=(0, hist_size))

    if clip_limit > 0:
        # clip histogram
        excess = np.maximum(hist - clip_limit, 0)
        clipped = np.sum(excess)
        hist = np.minimum(hist, clip_limit)

        # redistribute excess
        redist_batch = clipped // hist_size
        hist += redist_batch

        residual = clipped - redist_batch * hist_size
        if residual > 0:
            step = max(1, hist_size // residual)
            for i in range(0, hist_size, step):
                if residual <= 0:
                    break
                hist[i] += 1
                residual -= 1
    return hist

In [7]:
def calc_lut_from_hist(hist, lut_scale, hist_size, dtype):
    """Build LUT by cumulative sum of histogram."""
    lut = np.cumsum(hist) * lut_scale
    return np.clip(lut, 0, hist_size - 1).astype(dtype)


In [8]:
def compute_luts(src, tiles_x, tiles_y, clip_limit, hist_size, lut_scale, dtype):
    """Compute LUT for each tile."""
    h, w = src.shape
    tile_h = h // tiles_y
    tile_w = w // tiles_x

    luts = np.zeros((tiles_y, tiles_x, hist_size), dtype=dtype)

    for ty in range(tiles_y):
        for tx in range(tiles_x):
            tile = src[ty*tile_h:(ty+1)*tile_h, tx*tile_w:(tx+1)*tile_w]
            hist = calc_tile_histogram(tile, hist_size, clip_limit)
            lut = calc_lut_from_hist(hist, lut_scale, hist_size, dtype)
            luts[ty, tx] = lut
    return luts, (tile_h, tile_w)

In [9]:
def interpolate_apply(src, luts, tile_size, tiles_x, tiles_y, shift, dtype):
    """Apply CLAHE by bilinear interpolation between neighboring LUTs."""
    h, w = src.shape
    dst = np.zeros_like(src, dtype=dtype)
    tile_h, tile_w = tile_size

    for y in range(h):
        tyf = y / tile_h - 0.5
        ty1 = max(int(np.floor(tyf)), 0)
        ty2 = min(ty1 + 1, tiles_y - 1)
        ya = tyf - np.floor(tyf)
        ya1 = 1.0 - ya

        for x in range(w):
            txf = x / tile_w - 0.5
            tx1 = max(int(np.floor(txf)), 0)
            tx2 = min(tx1 + 1, tiles_x - 1)
            xa = txf - np.floor(txf)
            xa1 = 1.0 - xa

            val = src[y, x] >> shift

            # bilinear interpolation from 4 LUTs
            res = (
                luts[ty1, tx1, val] * xa1 * ya1 +
                luts[ty1, tx2, val] * xa  * ya1 +
                luts[ty2, tx1, val] * xa1 * ya  +
                luts[ty2, tx2, val] * xa  * ya
            )

            dst[y, x] = np.clip(round(res), 0, (1 << (16 - shift)) - 1) << shift
    return dst

In [10]:
def clahe(src, clip_limit=40.0, tiles_x=8, tiles_y=8):
    """Full CLAHE pipeline."""
    assert src.dtype in (np.uint8, np.uint16)

    hist_size = 256 if src.dtype == np.uint8 else 65536
    tile_size = (src.shape[0] // tiles_y, src.shape[1] // tiles_x)
    tile_area = tile_size[0] * tile_size[1]

    lut_scale = (hist_size - 1) / float(tile_area)

    if clip_limit > 0.0:
        clip_limit = max(int(clip_limit * tile_area / hist_size), 1)
    else:
        clip_limit = 0

    shift = 0  # in OpenCV code, shift can reduce bins (not used here)

    luts, tile_size = compute_luts(src, tiles_x, tiles_y, clip_limit, hist_size, lut_scale, src.dtype)
    dst = interpolate_apply(src, luts, tile_size, tiles_x, tiles_y, shift, src.dtype)
    return dst

In [None]:
container = av.open(FILE)


In [None]:
import os
import math
import av
import numpy as np
import cv2


def pad_reflect_101(img, tiles_x, tiles_y):
    """Pad so H and W are divisible by tiles grid, like BORDER_REFLECT_101."""
    h, w = img.shape
    pad_y = (tiles_y - (h % tiles_y))
    pad_x = (tiles_x - (w % tiles_x)) 
    if pad_y or pad_x:
        img = np.pad(img, ((0, pad_y), (0, pad_x)), mode='reflect')
    return img, (h, w)  # keep original size to crop back later


def clip_limit_to_count(clip_limit_ui, tile_area, hist_size):
    """OpenCV converts UI clipLimit to a per-tile count."""
    if clip_limit_ui <= 0.0:
        return 0
    count = int(clip_limit_ui * tile_area / hist_size)
    return max(count, 1)


def calc_tile_histogram(tile, hist_size, clip_count):
    """Histogram + clip + redistribute (like OpenCV's CPU path)."""
    # Fast histogram
    hist, _ = np.histogram(tile, bins=hist_size, range=(0, hist_size))

    if clip_count > 0:

        excess = np.maximum(hist - clip_count, 0)
        clipped_total = int(excess.sum())
        # clip
        hist = np.minimum(hist, clip_count)

        if clipped_total:
            # even batch
            redist_batch = clipped_total // hist_size
            hist += redist_batch
            # remainder spread
            residual = clipped_total - redist_batch * hist_size
    
            if residual > 0:
                step = max(1, hist_size // residual)
                # increment roughly uniformly
                for i in range(0, hist_size, step):
                    if residual <= 0:
                        break
                    hist[i] += 1
                    residual -= 1
    return hist


def lut_from_hist(hist, lut_scale, hist_size, dtype):
    """Cumulative sum with OpenCV-like rounding & clamp."""
    lut = np.rint(np.cumsum(hist) * lut_scale)  # round-to-nearest
    return np.clip(lut, 0, hist_size - 1).astype(dtype)


def compute_luts(src_padded, tiles_x, tiles_y, clip_count, hist_size, lut_scale, dtype):
    """Compute LUT per tile on the padded image."""
    H, W = src_padded.shape
    tile_h = H // tiles_y
    tile_w = W // tiles_x
    luts = np.empty((tiles_y, tiles_x, hist_size), dtype=dtype)

    for ty in range(tiles_y):
        y0 = ty * tile_h
        y1 = y0 + tile_h
        for tx in range(tiles_x):
            x0 = tx * tile_w
            x1 = x0 + tile_w
            tile = src_padded[y0:y1, x0:x1]
            hist = calc_tile_histogram(tile, hist_size, clip_count)
            luts[ty, tx] = lut_from_hist(hist, lut_scale, hist_size, dtype)
    return luts, (tile_h, tile_w)


def interpolate_apply(src_padded, luts, tile_size, tiles_x, tiles_y, shift, out_dtype):
    """Bilinear interpolation of 4 neighboring tile LUTs, like OpenCV."""
    H, W = src_padded.shape
    tile_h, tile_w = tile_size
    dst = np.empty_like(src_padded, dtype=out_dtype)

    inv_tw = 1.0 / tile_w
    inv_th = 1.0 / tile_h

    # Precompute X terms once (like OpenCV does)
    xa = np.empty(W, np.float32)
    xa1 = np.empty(W, np.float32)
    ind1 = np.empty(W, np.int32)
    ind2 = np.empty(W, np.int32)
    lut_stride = luts.shape[2]  # hist_size

    for x in range(W):
        txf = x * inv_tw - 0.5
        tx1 = math.floor(txf)
        tx2 = tx1 + 1
        frac = txf - tx1
        xa[x] = frac
        xa1[x] = 1.0 - frac
        tx1 = max(tx1, 0)
        tx2 = min(tx2, tiles_x - 1)
        ind1[x] = tx1 * lut_stride
        ind2[x] = tx2 * lut_stride

    for y in range(H):
        tyf = y * inv_th - 0.5
        ty1 = math.floor(tyf)
        ty2 = ty1 + 1
        ya = tyf - ty1
        ya1 = 1.0 - ya
        ty1 = max(ty1, 0)
        ty2 = min(ty2, tiles_y - 1)

        lutPlane1 = luts[ty1]  # shape: [tiles_x, hist_size]
        lutPlane2 = luts[ty2]

        src_row = src_padded[y]
        dst_row = dst[y]

        # vectorized over x is tricky due to per-pixel indexing; do a tight loop
        for x in range(W):
            val = int(src_row[x]) >> shift
            # indices inside LUT planes
            i1 = ind1[x] + val
            i2 = ind2[x] + val

            # bilinear interp over 4 LUTs
            r = (lutPlane1.flat[i1] * xa1[x] + lutPlane1.flat[i2] * xa[x]) * ya1 + \
                (lutPlane2.flat[i1] * xa1[x] + lutPlane2.flat[i2] * xa[x]) * ya

            # saturate + shift back
            if out_dtype == np.uint8:
                dst_row[x] = np.uint8(0 if r < 0 else 255 if r > 255 else int(r + 0.5))  # round & clamp
            else:
                # 16-bit domain before << shift
                r_i = 0 if r < 0 else (65535 >> shift) if r > (65535 >> shift) else int(r + 0.5)
                dst_row[x] = np.uint16(r_i << shift)

    return dst


def clahe_like_opencv(img, clip_limit_ui, tiles_x, tiles_y, shift):
    """
    Apply CLAHE to uint8 or uint16 single-channel image.
    - Matches OpenCV behavior: padding, clip-limit normalization, bilinear LUT interpolation.
    - 'shift' lets you down-bin (e.g., shift=4 for 12-bit bins on 16-bit image).
    """

    
    bitdepth =16
    bins = 1 << (bitdepth - shift)  # effective histogram size
    dtype = img.dtype
    print(dtype)
    # Pad to tile grid
    #later: i will need to add if statment to see if i need to use this def
    padded, (orig_h, orig_w) = pad_reflect_101(img, tiles_x, tiles_y)

    # Tile geometry
    H, W = padded.shape
    tile_h = H // tiles_y
    tile_w = W // tiles_x
    tile_area = tile_h * tile_w

    # OpenCV clip-limit conversion
    clip_count = clip_limit_to_count(clip_limit_ui, tile_area, bins)

    # LUT scale, matching OpenCV: (bins - 1) / tile_area
    lut_scale = float(bins - 1) / float(tile_area)

    # Build LUTs on the padded image
    luts, tile_size = compute_luts(padded, tiles_x, tiles_y, clip_count, bins, lut_scale, dtype)

    # Interpolate-apply on the padded image
    processed = interpolate_apply(padded, luts, tile_size, tiles_x, tiles_y, shift, dtype)

    # Crop back to original size
    return processed[:orig_h, :orig_w]


# -------------------------------------------
# PyAV: MKV -> CLAHE -> (PNG sequence or MKV)
# -------------------------------------------

def process_mkv_with_av(
    input_path,
    output_frames_dir="out_frames_16bit",          # if set, saves 16-bit PNGs here
    output_mkv="out_clahe_16bit.mkv",                 # if set, writes MKV video
    tiles_x=5,
    tiles_y=5,
    clip_limit_ui=120.0,
    shift=0,
    preserve_16bit=True,             # True -> process/keep 16-bit if available
    target_fps=None,                 # None -> use source fps
    codec_8bit='libx264',            # for 8-bit MKV
    codec_16bit='ffv1'               # for 16-bit MKV (lossless)
):
    """
    Reads MKV with PyAV, applies CLAHE, and writes either 16-bit PNGs or MKV.
    - If preserve_16bit is True, frames are decoded as gray16le; else gray(8-bit).
    - For MKV output:
        * 8-bit path uses H.264 for wide compatibility.
        * 16-bit path uses FFV1 (lossless) if available.
    """
    assert output_frames_dir or output_mkv, "Provide output_frames_dir and/or output_mkv."

    container = av.open(input_path)
    vstream = container.streams.video[0]
    # Prepare MKV writer if requested
    out_container = 1
    out_stream = 1
    out_w = vstream.width
    out_h = vstream.height

    if output_mkv:
        out_container = av.open(output_mkv, mode='w')

        # 16-bit lossless path (FFV1 supports gray16le in MKV)
        out_stream = out_container.add_stream(codec_16bit)
        out_stream.width = out_w
        out_stream.height = out_h
        # pix_fmt chosen from frames (gray16le) will guide encoder

    # Prepare PNG directory if requested
    if output_frames_dir:
        os.makedirs(output_frames_dir, exist_ok=True)

    frame_index = 0

    for frame in container.decode(video=0):
        # Decode to numpy gray
        img = frame.to_ndarray(format='gray16le').astype(np.uint16)
        

        # Apply CLAHE (matches OpenCV behavior closely)
        img_clahe = clahe_like_opencv(img, clip_limit_ui=clip_limit_ui,
                                      tiles_x=tiles_x, tiles_y=tiles_y, shift=shift)
        img_clahe_n=negative(img_clahe)
        img_clahe_n_m = cv2.medianBlur(img_clahe,3)
        # Write PNGs (16-bit preserved if dtype is uint16)
        if output_frames_dir:
            out_path = os.path.join(output_frames_dir, f"frame_{frame_index:06d}.png")
            cv2.imwrite(out_path, img_clahe_n_m)
        
        # Write MKV
        if out_container is not None:
            if img_clahe.dtype == np.uint16:
                vf = av.VideoFrame.from_ndarray(img_clahe_n_m, format='gray16le')
            else:
                vf = av.VideoFrame.from_ndarray(img_clahe_n_m, format='gray')
            for packet in out_stream.encode(vf):
                out_container.mux(packet)

        frame_index += 1


    container.close()




if __name__ == "__main__":
    # Example 1: 16-bit in/out (PNG sequence + 16-bit MKV with FFV1)
    process_mkv_with_av(input_path=FILE,)


uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16
uint16


In [1]:
FILE = "vids/Almesbar_face_20_08_2025_1.mkv"

In [3]:
def negative(arr):
    """Return the negative of a numpy array (255 - value for 8-bit images)."""
    return np.abs(65535 - arr)

In [60]:
def show_freq(hist,msg):
    hist, bin_edges = np.histogram(img, bins=256, range=(0, 65536))

    plt.figure(figsize=(8, 4))
    plt.bar(bin_edges[:-1], hist, width=np.diff(bin_edges), edgecolor="black", align="edge")
    plt.xlabel("Pixel value (16-bit)")
    plt.ylabel("Frequency")
    plt.title(msg)
    plt.show()

In [None]:
def add_padding

In [92]:
def median_filter(img, kernal=3):
    padd_image=add_padding(img,kernal)
    H,W=img.shape()

    for x in range(H+1):
        for y+1 in range(W+kernal//2):
            h = x+kernal//2
            w = y+kernal//2
            padd_image[h,w]=median_of_kernal(padd_image[[x,y],[h+(h-x),w+(w-y)]],kernal)

_IncompleteInputError: incomplete input (3137542128.py, line 1)

In [None]:
def add_padding(img,kernal):
    kernal=kernal//2
    padded_image= np.pad(img, pad_width=kernal, mode='symmetric')

    return padd_image

In [None]:
def median_of_kernal(kernal,size):
    12323
    