# Multiplexing Pixels Experiments

## Generating multiplexed dataset

In [None]:
import math
import numpy as np
import imageio.v2 as imageio
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from typing import Tuple, List


def get_comap(
    num_lens: int, d_lens_sensor: int, H: int, W: int
) -> Tuple[np.ndarray, List[int]]:
    """Build per-microlens coordinate maps for a square grid of microlenses.

    Each microlens projects a square patch onto an ``H x W`` sensor. The patch side
    is ``(min(H // n, W // n) // 12) * d_lens_sensor`` pixels, rounded down to an even
    number, where ``n = sqrt(num_lens)``. Patch centres are spaced with
    ``np.linspace`` so the outermost patches touch the sensor edges; neighbouring
    patches overlap when the patch side exceeds the centre spacing.

    Args:
        num_lens: Total number of microlenses; must be a positive perfect square.
        d_lens_sensor: Multiplier for the patch side (larger -> bigger patches,
            more overlap between neighbours).
        H: Sensor height in pixels.
        W: Sensor width in pixels.

    Returns:
        ``(comap_yx, [microlens_height, microlens_width])``. ``comap_yx`` has shape
        ``(num_lens, H, W, 2)``; ``comap_yx[i, y, x]`` is the local ``(row, col)``
        inside lens ``i``'s patch that lands on sensor pixel ``(y, x)``, or
        ``(-1, -1)`` where lens ``i`` does not cover that pixel.

    Raises:
        ValueError: If ``num_lens`` is not a positive perfect square.
    """
    # Integer square-root check (exact, unlike comparing float sqrt results) and a
    # real exception instead of `assert`, which is stripped under `python -O`.
    if num_lens < 1 or math.isqrt(num_lens) ** 2 != num_lens:
        raise ValueError("Number of sublens should be a square number")
    side = math.isqrt(num_lens)

    # Patch side in pixels; forced even so it splits evenly about its centre.
    base_microlens_size = min(H // side, W // side) // 12
    microlens_height = int(base_microlens_size * d_lens_sensor)
    microlens_height -= microlens_height % 2
    microlens_width = microlens_height  # microlenses are square
    half_h, half_w = microlens_height // 2, microlens_width // 2

    comap_yx = -np.ones((num_lens, H, W, 2))

    # Spread centres edge to edge; a single row/column is centred instead.
    if side > 1:
        y_positions = np.linspace(half_h, H - half_h, side)
        x_positions = np.linspace(half_w, W - half_w, side)
    else:
        y_positions = np.array([H // 2])
        x_positions = np.array([W // 2])

    for i in range(num_lens):
        row, col = divmod(i, side)
        center_y, center_x = int(y_positions[row]), int(x_positions[col])
        start_y = max(0, center_y - half_h)
        end_y = min(H, center_y + half_h)
        start_x = max(0, center_x - half_w)
        end_x = min(W, center_x + half_w)
        if end_y <= start_y or end_x <= start_x:
            continue  # patch falls entirely off the sensor

        # Vectorised fill: local coordinates count from the (clipped) patch origin.
        comap_yx[i, start_y:end_y, start_x:end_x, 0] = np.arange(end_y - start_y)[:, None]
        comap_yx[i, start_y:end_y, start_x:end_x, 1] = np.arange(end_x - start_x)[None, :]

    dim_lens_lf_yx = [microlens_height, microlens_width]
    return comap_yx, dim_lens_lf_yx


def read_images(num_lens, model_path, base):
    """Load the sublens renders ``{model_path}/r_{base}_{j}.png`` for j in [0, num_lens).

    Each image is scaled from 8-bit to [0, 1] float32, reduced to its first three
    (RGB) channels, permuted to channel-first layout and moved to the module-level
    ``device``. Returns a list of tensors in view-index order.
    """

    def _load(view_idx):
        raw = imageio.imread(f"{model_path}/r_{base}_{view_idx}.png")
        scaled = raw.astype(np.float32) / 255.0
        return torch.from_numpy(scaled[:, :, :3]).permute(2, 0, 1).to(device)

    return [_load(view_idx) for view_idx in range(num_lens)]


def get_max_overlap(comap_yx, num_lens, H, W):
    """Return the largest number of microlenses covering any single sensor pixel.

    Lens ``i`` covers pixel ``(y, x)`` when ``comap_yx[i, y, x, 1] != -1``. Coverage
    is accumulated in an int32 ``H x W`` tensor on the module-level ``device``; its
    maximum is returned as a 0-dim tensor.
    """
    coverage = torch.zeros(H, W, dtype=torch.int32, device=device)
    for lens_idx in range(num_lens):
        covered = comap_yx[lens_idx, :, :, 1] != -1
        coverage.add_(covered)
    return coverage.max()


def generate_sub_images(
    images, comap_yx, dim_lens_lf_yx, num_lens, sensor_size, alpha_map=None
):
    """Render each microlens' view onto its own sensor-sized canvas.

    Args:
        images: List of ``num_lens`` CHW float tensors with 3 channels, one per
            sublens view.
        comap_yx: Tensor ``(num_lens, sensor_size, sensor_size, 2)`` of per-pixel
            local lens coordinates (``-1`` where the lens does not cover the pixel),
            as produced by ``get_comap``. All work happens on its device.
        dim_lens_lf_yx: ``[height, width]`` every view is resized to.
        num_lens: Number of microlenses; assumed to be a perfect square.
        sensor_size: Side length of the square sensor in pixels.
        alpha_map: Optional ``(sensor_size, sensor_size)`` per-pixel blending weights
            (tensor or array). When given, a fourth channel is appended holding
            ``alpha_map`` on the pixels each lens covers and 0 elsewhere.

    Returns:
        Float32 tensor ``(num_lens, C, sensor_size, sensor_size)`` with ``C == 3``,
        or ``C == 4`` when ``alpha_map`` is supplied. Uncovered pixels are 0.
    """
    # Use the coordinate map's device rather than a hidden module-level global.
    dev = comap_yx.device
    alpha = None
    if alpha_map is not None:
        alpha = torch.as_tensor(alpha_map, dtype=torch.float32, device=dev)

    num_channels = 3 if alpha is None else 4
    sub_images = torch.zeros(
        num_lens, num_channels, sensor_size, sensor_size, device=dev, dtype=torch.float32
    )

    # Lens k = i * grid_size + j (grid row i, column j) shows view
    # (grid_size - 1 - i) + (grid_size - 1 - j) * grid_size.
    grid_size = int(math.sqrt(num_lens))
    idx = torch.arange(grid_size, device=dev)
    grid_i, grid_j = torch.meshgrid(idx, idx, indexing="ij")
    mapping = ((grid_size - 1 - grid_i) + (grid_size - 1 - grid_j) * grid_size).reshape(-1)

    images_tensor = torch.stack(images, dim=0).to(dev)
    resized_images = F.interpolate(
        images_tensor[mapping],
        size=(dim_lens_lf_yx[0], dim_lens_lf_yx[1]),
        mode="bilinear",
        align_corners=False,
    )

    lens_h, lens_w = dim_lens_lf_yx[0], dim_lens_lf_yx[1]
    for i in range(num_lens):
        y_coords = comap_yx[i, :, :, 0]
        x_coords = comap_yx[i, :, :, 1]
        # `>= 0` also excludes the -1 "not covered" sentinel.
        valid_mask = (
            (y_coords >= 0)
            & (y_coords < lens_h)
            & (x_coords >= 0)
            & (x_coords < lens_w)
        )
        # Skip lenses with no valid pixel; previously the scatter below ran anyway
        # and used undefined (or the previous lens's) indices.
        if not valid_mask.any():
            continue

        y_indices, x_indices = torch.where(valid_mask)
        y_src = y_coords[valid_mask].long()
        x_src = x_coords[valid_mask].long()
        sub_images[i, :3, y_indices, x_indices] = resized_images[i, :, y_src, x_src]
        if alpha is not None:
            sub_images[i, 3, y_indices, x_indices] = alpha[y_indices, x_indices]

    return sub_images


def generate(images, comap_yx, dim_lens_lf_yx, num_lens, H, W, max_overlap):
    """Multiplex all sublens views onto a single ``3 x H x W`` sensor image.

    Views are reordered so lens ``k = i * side + j`` shows view
    ``(side - 1 - i) + (side - 1 - j) * side``, resized to ``dim_lens_lf_yx``, and
    scattered onto the sensor through ``comap_yx``. Overlapping contributions are
    summed, and the sum is divided by ``max_overlap``. Tensors live on the
    module-level ``device``.
    """
    side = int(math.sqrt(num_lens))
    axis = torch.arange(side, device=device)
    rows, cols = torch.meshgrid(axis, axis, indexing="ij")
    view_order = (side - 1 - rows + (side - 1 - cols) * side).reshape(-1)

    stacked = torch.stack(images, dim=0).to(device)
    patches = F.interpolate(
        stacked[view_order],
        size=(dim_lens_lf_yx[0], dim_lens_lf_yx[1]),
        mode="bilinear",
        align_corners=False,
    )
    lens_h, lens_w = dim_lens_lf_yx[0], dim_lens_lf_yx[1]

    canvas = torch.zeros(3, H, W, device=device, dtype=torch.float32)
    for lens in range(num_lens):
        ys = comap_yx[lens, :, :, 0]
        xs = comap_yx[lens, :, :, 1]
        inside = (
            (ys != -1)
            & (xs != -1)
            & (ys >= 0)
            & (ys < lens_h)
            & (xs >= 0)
            & (xs < lens_w)
        )
        if not inside.any():
            continue  # this lens touches no sensor pixel
        dst_y, dst_x = torch.where(inside)
        src_y = ys[inside].long()
        src_x = xs[inside].long()
        canvas[:, dst_y, dst_x] += patches[lens, :, src_y, src_x]

    return canvas / max_overlap


def plot_sub_images(sub_images, num_lens):
    """Show every microlens sub-image in a square grid of subplots.

    Args:
        sub_images: Tensor ``(num_lens, C, H, W)``, e.g. from ``generate_sub_images``.
        num_lens: Number of sub-images; assumed to be a perfect square.
    """
    grid_size = int(np.sqrt(num_lens))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid; by default
    # plt.subplots(1, 1) returns a bare Axes that cannot be indexed as axes[i, j].
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(15, 15), squeeze=False)

    # axes.flat walks the grid row-major, so idx == row * grid_size + col.
    for idx, ax in enumerate(axes.flat):
        ax.imshow(sub_images[idx].detach().cpu().permute(1, 2, 0).numpy())
        ax.set_title(f"Microlens {idx}")
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# Experiment configuration: a 4x4 microlens grid (get_comap requires a perfect
# square) on an 800x800 sensor. d_lens_sensor scales the microlens patch side;
# larger values give larger patches, which overlap more.
NUM_LENS = 16
SENSOR_SIZE = 800
d_lens_sensor = 18
# Directory holding the per-sublens renders r_{base}_{j}.png read by read_images.
model_path = "/home/wl757/multiplexed-pixels/plenoxels/blender_data/lego_gen12/new_multiplexed_views"
base = "59"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generate coordinate mapping
comap_yx, dim_lens_lf_yx = get_comap(NUM_LENS, d_lens_sensor, SENSOR_SIZE, SENSOR_SIZE)
print(dim_lens_lf_yx)
# plt.imshow(comap_yx[0, :, :, 0])
# plt.colorbar()
# plt.show()
comap_yx = torch.from_numpy(comap_yx).to(device)

# Load the sublens views and multiplex them onto the sensor.
images = read_images(NUM_LENS, model_path, base)
# Largest number of lenses covering any one pixel; generate divides the summed
# image by this value.
max_overlap = get_max_overlap(comap_yx, NUM_LENS, SENSOR_SIZE, SENSOR_SIZE)
combined = generate(
    images, comap_yx, dim_lens_lf_yx, NUM_LENS, SENSOR_SIZE, SENSOR_SIZE, max_overlap
)
print(max_overlap)

print(combined.min(), combined.max())
plt.imshow(combined.cpu().permute(1, 2, 0).numpy())
plt.axis("off")
plt.show()

In [None]:
import torch
from kornia.enhance.equalization import equalize_clahe

# Sanity check: run kornia's CLAHE in its differentiable mode on a random batch
# under autograd anomaly detection, backpropagate a scalar loss, and print the
# resulting input gradient and its sum.
with torch.autograd.detect_anomaly():
    img = torch.rand((2, 3, 10, 20), requires_grad=True)
    res = equalize_clahe(img, slow_and_differentiable=True)
    res.sum().backward()
    print(img.grad, img.grad.sum())

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch


def get_comap(num_lens, d_lens_sensor, H, W):
    """Build per-microlens coordinate maps for a square grid of microlenses.

    Each microlens projects a square patch onto an ``H x W`` sensor. The patch side
    is ``(min(H // n, W // n) // 12) * d_lens_sensor`` pixels, rounded down to an even
    number, where ``n = sqrt(num_lens)``. Patch centres are spaced with
    ``np.linspace`` so the outermost patches touch the sensor edges; neighbouring
    patches overlap when the patch side exceeds the centre spacing.

    Args:
        num_lens: Total number of microlenses; must be a positive perfect square.
        d_lens_sensor: Multiplier for the patch side (larger -> more overlap).
        H: Sensor height in pixels.
        W: Sensor width in pixels.

    Returns:
        ``(comap_yx, [microlens_height, microlens_width])``. ``comap_yx`` has shape
        ``(num_lens, H, W, 2)``; ``comap_yx[i, y, x]`` is the local ``(row, col)``
        inside lens ``i``'s patch that lands on sensor pixel ``(y, x)``, or
        ``(-1, -1)`` where lens ``i`` does not cover that pixel.

    Raises:
        ValueError: If ``num_lens`` is not a positive perfect square.
    """
    # Exact integer check; raise instead of `assert` (stripped under `python -O`).
    if num_lens < 1 or math.isqrt(num_lens) ** 2 != num_lens:
        raise ValueError("Number of sublens should be a square number")
    side = math.isqrt(num_lens)

    # Patch side in pixels; forced even so it splits evenly about its centre.
    base_microlens_size = min(H // side, W // side) // 12
    microlens_height = int(base_microlens_size * d_lens_sensor)
    microlens_height -= microlens_height % 2
    microlens_width = microlens_height  # microlenses are square
    half_h, half_w = microlens_height // 2, microlens_width // 2

    comap_yx = -np.ones((num_lens, H, W, 2))

    # Spread centres edge to edge; a single row/column is centred instead.
    if side > 1:
        y_positions = np.linspace(half_h, H - half_h, side)
        x_positions = np.linspace(half_w, W - half_w, side)
    else:
        y_positions = np.array([H // 2])
        x_positions = np.array([W // 2])

    for i in range(num_lens):
        row, col = divmod(i, side)
        center_y, center_x = int(y_positions[row]), int(x_positions[col])
        start_y = max(0, center_y - half_h)
        end_y = min(H, center_y + half_h)
        start_x = max(0, center_x - half_w)
        end_x = min(W, center_x + half_w)
        if end_y <= start_y or end_x <= start_x:
            continue  # patch falls entirely off the sensor

        # Vectorised fill: local coordinates count from the (clipped) patch origin.
        comap_yx[i, start_y:end_y, start_x:end_x, 0] = np.arange(end_y - start_y)[:, None]
        comap_yx[i, start_y:end_y, start_x:end_x, 1] = np.arange(end_x - start_x)[None, :]

    dim_lens_lf_yx = [microlens_height, microlens_width]
    return comap_yx, dim_lens_lf_yx


def read_images(num_lens, model_path, base):
    """Load the sublens renders ``{model_path}/r_{base}_{j}.png`` for j in [0, num_lens).

    Each image is scaled from 8-bit to [0, 1] float32, reduced to its first three
    (RGB) channels, permuted to channel-first layout and moved to the module-level
    ``device``. Returns a list of tensors in view-index order.
    """

    def _load(view_idx):
        raw = imageio.imread(f"{model_path}/r_{base}_{view_idx}.png")
        scaled = raw.astype(np.float32) / 255.0
        return torch.from_numpy(scaled[:, :, :3]).permute(2, 0, 1).to(device)

    return [_load(view_idx) for view_idx in range(num_lens)]


def get_max_overlap(comap_yx, num_lens, H, W):
    """Return the largest number of microlenses covering any single sensor pixel.

    Lens ``i`` covers pixel ``(y, x)`` when ``comap_yx[i, y, x, 1] != -1``. Coverage
    is accumulated in an int32 ``H x W`` tensor on the module-level ``device``; its
    maximum is returned as a 0-dim tensor.
    """
    coverage = torch.zeros(H, W, dtype=torch.int32, device=device)
    for lens_idx in range(num_lens):
        covered = comap_yx[lens_idx, :, :, 1] != -1
        coverage.add_(covered)
    return coverage.max()


# Configuration for this run: 4x4 microlens grid on a 400x400 sensor.
NUM_LENS = 16
SENSOR_SIZE = 400
d_lens_sensor = 19
# Directory holding the per-sublens renders r_{base}_{j}.png read by read_images.
model_path = "/home/wl757/multiplexed-pixels/plenoxels/blender_data/lego_gen12/train_multilens_16_black"
base = "59"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generate coordinate mapping
comap_yx, dim_lens_lf_yx = get_comap(NUM_LENS, d_lens_sensor, SENSOR_SIZE, SENSOR_SIZE)
print(dim_lens_lf_yx)
# plt.imshow(comap_yx[0, :, :, 0])
# plt.colorbar()
# plt.show()
comap_yx = torch.from_numpy(comap_yx).to(device)

# Load the sublens views and scatter each onto its own sensor-sized canvas.
images = read_images(NUM_LENS, model_path, base)
sub_images = generate_sub_images(
    images, comap_yx, dim_lens_lf_yx, NUM_LENS, SENSOR_SIZE
)
# Printed for reference only; the normalisation below uses combined.max() instead.
max_overlap = get_max_overlap(comap_yx, NUM_LENS, SENSOR_SIZE, SENSOR_SIZE)
print(max_overlap)

# Plot all sub-images
# plot_sub_images(sub_images, NUM_LENS)
# Sum the per-lens canvases, then rescale so the brightest value is 1 for display.
combined = torch.sum(sub_images, axis=0)
combined = combined / combined.max()
plt.imshow(combined.cpu().permute(1, 2, 0).numpy())
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from kornia import enhance
import torch


def get_comap(num_lens, d_lens_sensor, H, W):
    """Build per-microlens coordinate maps for a square grid of microlenses.

    Each microlens projects a square patch onto an ``H x W`` sensor. The patch side
    is ``(min(H // n, W // n) // 12) * d_lens_sensor`` pixels, rounded down to an even
    number, where ``n = sqrt(num_lens)``. Patch centres are spaced with
    ``np.linspace`` so the outermost patches touch the sensor edges; neighbouring
    patches overlap when the patch side exceeds the centre spacing.

    Args:
        num_lens: Total number of microlenses; must be a positive perfect square.
        d_lens_sensor: Multiplier for the patch side (larger -> more overlap).
        H: Sensor height in pixels.
        W: Sensor width in pixels.

    Returns:
        ``(comap_yx, [microlens_height, microlens_width])``. ``comap_yx`` has shape
        ``(num_lens, H, W, 2)``; ``comap_yx[i, y, x]`` is the local ``(row, col)``
        inside lens ``i``'s patch that lands on sensor pixel ``(y, x)``, or
        ``(-1, -1)`` where lens ``i`` does not cover that pixel.

    Raises:
        ValueError: If ``num_lens`` is not a positive perfect square.
    """
    # Exact integer check; raise instead of `assert` (stripped under `python -O`).
    if num_lens < 1 or math.isqrt(num_lens) ** 2 != num_lens:
        raise ValueError("Number of sublens should be a square number")
    side = math.isqrt(num_lens)

    # Patch side in pixels; forced even so it splits evenly about its centre.
    base_microlens_size = min(H // side, W // side) // 12
    microlens_height = int(base_microlens_size * d_lens_sensor)
    microlens_height -= microlens_height % 2
    microlens_width = microlens_height  # microlenses are square
    half_h, half_w = microlens_height // 2, microlens_width // 2

    comap_yx = -np.ones((num_lens, H, W, 2))

    # Spread centres edge to edge; a single row/column is centred instead.
    if side > 1:
        y_positions = np.linspace(half_h, H - half_h, side)
        x_positions = np.linspace(half_w, W - half_w, side)
    else:
        y_positions = np.array([H // 2])
        x_positions = np.array([W // 2])

    for i in range(num_lens):
        row, col = divmod(i, side)
        center_y, center_x = int(y_positions[row]), int(x_positions[col])
        start_y = max(0, center_y - half_h)
        end_y = min(H, center_y + half_h)
        start_x = max(0, center_x - half_w)
        end_x = min(W, center_x + half_w)
        if end_y <= start_y or end_x <= start_x:
            continue  # patch falls entirely off the sensor

        # Vectorised fill: local coordinates count from the (clipped) patch origin.
        comap_yx[i, start_y:end_y, start_x:end_x, 0] = np.arange(end_y - start_y)[:, None]
        comap_yx[i, start_y:end_y, start_x:end_x, 1] = np.arange(end_x - start_x)[None, :]

    dim_lens_lf_yx = [microlens_height, microlens_width]
    return comap_yx, dim_lens_lf_yx


def read_images(num_lens, model_path, base):
    """Load the sublens renders ``{model_path}/r_{base}_{j}.png`` for j in [0, num_lens).

    Unlike the tensor variants elsewhere in this notebook, this version returns
    NumPy arrays in height x width x channel layout: each image is scaled from
    8-bit to [0, 1] float32 and reduced to its first three (RGB) channels.
    """
    return [
        (
            imageio.imread(f"{model_path}/r_{base}_{view_idx}.png").astype(np.float32)
            / 255.0
        )[:, :, :3]
        for view_idx in range(num_lens)
    ]


def generate_alpha_map(comap_yx, num_lens, H, W):
    """Compute per-pixel blending weights from lens coverage.

    Lens ``i`` covers pixel ``(y, x)`` when ``comap_yx[i, y, x, 0] != -1``.

    Returns:
        ``(alpha_map, overlap_count)``: ``overlap_count`` is an int32 ``H x W`` array
        counting covering lenses; ``alpha_map`` is a float64 ``H x W`` array equal to
        ``1 / overlap_count`` where at least one lens covers the pixel and 0 elsewhere.
    """
    coverage = np.zeros((H, W), dtype=np.int32)
    for lens_idx in range(num_lens):
        coverage += comap_yx[lens_idx, :, :, 0] != -1

    # Divide only where covered; uncovered pixels keep the zero from `out`.
    weights = np.divide(1.0, coverage, out=np.zeros((H, W)), where=coverage > 0)
    return weights, coverage


NUM_LENS = 16
SENSOR_SIZE = 800
d_lens_sensor = 20
model_path = "/home/wl757/multiplexed-pixels/plenoxels/blender_data/lego_gen12/train_multilens_16_black"
base = "59"

# Generate coordinate mapping (NumPy array, shape (NUM_LENS, SENSOR_SIZE, SENSOR_SIZE, 2)).
comap_yx, dim_lens_lf_yx = get_comap(NUM_LENS, d_lens_sensor, SENSOR_SIZE, SENSOR_SIZE)
# print(dim_lens_lf_yx)
# plt.imshow(comap_yx[0, :, :, 0])
# plt.colorbar()
# plt.show()

# alpha_map[y, x] = 1 / (number of lenses covering (y, x)), 0 where uncovered.
alpha_map, overlap_count = generate_alpha_map(
    comap_yx, NUM_LENS, SENSOR_SIZE, SENSOR_SIZE
)
# plt.imshow(alpha_map)
# plt.colorbar()
# plt.show()

# Generate and get sub-images.
# This cell's read_images returns HWC NumPy arrays. generate_sub_images instead
# stacks CHW torch tensors and indexes them through a torch coordinate map
# (passing NumPy raised a TypeError), so convert both on the way in. `device` is
# defined in an earlier cell. The result is converted back to
# (NUM_LENS, H, W, 3) NumPy for the NumPy compositing below.
images = [
    torch.from_numpy(im).permute(2, 0, 1).to(device)
    for im in read_images(NUM_LENS, model_path, base)
]
sub_images = (
    generate_sub_images(
        images,
        torch.from_numpy(comap_yx).to(device),
        dim_lens_lf_yx,
        NUM_LENS,
        SENSOR_SIZE,
    )
    .permute(0, 2, 3, 1)
    .cpu()
    .numpy()
)

combined = np.sum(sub_images, axis=0)
# Sub-images have only RGB channels, so there is no per-image alpha to read.
# Weight each one by alpha_map instead. Every sub-image is zero outside its lens
# footprint, so the sum averages the overlapping lenses at each pixel.
output = np.zeros((SENSOR_SIZE, SENSOR_SIZE, 3))
for rgb in sub_images:
    output += rgb * alpha_map[:, :, np.newaxis]
combined_rgb = combined[:, :, :3]
print("overlap max: ", overlap_count.max())
print("overlap min: ", overlap_count.min())
combined_output = combined_rgb / overlap_count.max()
# combined_output = combined_output / combined_output.max()
print(combined_output.min(), combined_output.max())
plt.imshow(combined_rgb / combined_rgb.max())
plt.show()
plt.imshow(combined_output)
plt.show()
plt.imshow(
    enhance.equalize(torch.from_numpy(combined_output).permute(2, 0, 1))
    .permute(1, 2, 0)
    .detach()
    .numpy()
)
plt.show()
plt.imshow(combined)
plt.show()
plt.imshow(output)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

# Select the GPU when available; read_images and generate in this cell place their
# tensors on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_comap(num_lens, d_lens_sensor, H, W):
    """Build per-microlens coordinate maps for a square grid of microlenses.

    Each microlens projects a square patch onto an ``H x W`` sensor. The patch side
    is ``(min(H // n, W // n) // 12) * d_lens_sensor`` pixels, rounded down to an even
    number, where ``n = sqrt(num_lens)``. Patch centres are spaced with
    ``np.linspace`` so the outermost patches touch the sensor edges; neighbouring
    patches overlap when the patch side exceeds the centre spacing.

    Args:
        num_lens: Total number of microlenses; must be a positive perfect square.
        d_lens_sensor: Multiplier for the patch side (larger -> more overlap).
        H: Sensor height in pixels.
        W: Sensor width in pixels.

    Returns:
        ``(comap_yx, [microlens_height, microlens_width])``. ``comap_yx`` has shape
        ``(num_lens, H, W, 2)``; ``comap_yx[i, y, x]`` is the local ``(row, col)``
        inside lens ``i``'s patch that lands on sensor pixel ``(y, x)``, or
        ``(-1, -1)`` where lens ``i`` does not cover that pixel.

    Raises:
        ValueError: If ``num_lens`` is not a positive perfect square.
    """
    # Exact integer check; raise instead of `assert` (stripped under `python -O`).
    if num_lens < 1 or math.isqrt(num_lens) ** 2 != num_lens:
        raise ValueError("Number of sublens should be a square number")
    side = math.isqrt(num_lens)

    # Patch side in pixels; forced even so it splits evenly about its centre.
    base_microlens_size = min(H // side, W // side) // 12
    microlens_height = int(base_microlens_size * d_lens_sensor)
    microlens_height -= microlens_height % 2
    microlens_width = microlens_height  # microlenses are square
    half_h, half_w = microlens_height // 2, microlens_width // 2

    comap_yx = -np.ones((num_lens, H, W, 2))

    # Spread centres edge to edge; a single row/column is centred instead.
    if side > 1:
        y_positions = np.linspace(half_h, H - half_h, side)
        x_positions = np.linspace(half_w, W - half_w, side)
    else:
        y_positions = np.array([H // 2])
        x_positions = np.array([W // 2])

    for i in range(num_lens):
        row, col = divmod(i, side)
        center_y, center_x = int(y_positions[row]), int(x_positions[col])
        start_y = max(0, center_y - half_h)
        end_y = min(H, center_y + half_h)
        start_x = max(0, center_x - half_w)
        end_x = min(W, center_x + half_w)
        if end_y <= start_y or end_x <= start_x:
            continue  # patch falls entirely off the sensor

        # Vectorised fill: local coordinates count from the (clipped) patch origin.
        comap_yx[i, start_y:end_y, start_x:end_x, 0] = np.arange(end_y - start_y)[:, None]
        comap_yx[i, start_y:end_y, start_x:end_x, 1] = np.arange(end_x - start_x)[None, :]

    dim_lens_lf_yx = [microlens_height, microlens_width]
    return comap_yx, dim_lens_lf_yx


def read_images(num_lens, model_path, base):
    """Load the sublens renders ``{model_path}/r_{base}_{j}.png`` for j in [0, num_lens).

    Each image is scaled from 8-bit to [0, 1] float32, reduced to its first three
    (RGB) channels, permuted to channel-first layout and moved to the module-level
    ``device``. Returns a list of tensors in view-index order.
    """

    def _load(view_idx):
        raw = imageio.imread(f"{model_path}/r_{base}_{view_idx}.png")
        scaled = raw.astype(np.float32) / 255.0
        return torch.from_numpy(scaled[:, :, :3]).permute(2, 0, 1).to(device)

    return [_load(view_idx) for view_idx in range(num_lens)]


def generate_alpha_map(comap_yx, num_lens, H, W):
    """Compute per-pixel blending weights from lens coverage.

    Lens ``i`` covers pixel ``(y, x)`` when ``comap_yx[i, y, x, 0] != -1``. Returns a
    float64 ``H x W`` array equal to ``1 / (number of covering lenses)`` where at
    least one lens covers the pixel, and 0 elsewhere.
    """
    coverage = np.zeros((H, W), dtype=np.int32)
    for lens_idx in range(num_lens):
        coverage += comap_yx[lens_idx, :, :, 0] != -1

    # Divide only where covered; uncovered pixels keep the zero from `out`.
    return np.divide(1.0, coverage, out=np.zeros((H, W)), where=coverage > 0)


def generate(images, comap_yx, dim_lens_lf_yx, num_lens, sensor_size, alpha_map):
    """Blend all sublens views into one square sensor image using per-pixel weights.

    Views are reordered so lens ``k = i * side + j`` shows view
    ``(side - 1 - i) + (side - 1 - j) * side`` (``side = sqrt(num_lens)``), resized
    to ``dim_lens_lf_yx``, and scattered onto a ``3 x sensor_size x sensor_size``
    canvas through ``comap_yx``. Each contribution is multiplied by ``alpha_map``
    at its destination pixel before being accumulated. The result is clamped to
    [0, 1]. Tensors live on the module-level ``device``.

    Args:
        images: List of CHW view tensors.
        comap_yx: Per-lens coordinate maps ``(num_lens, sensor_size, sensor_size, 2)``,
            with ``-1`` marking pixels a lens does not cover.
        dim_lens_lf_yx: ``[height, width]`` each view is resized to.
        num_lens: Number of microlenses (assumed a perfect square).
        sensor_size: Side length of the square output.
        alpha_map: ``(sensor_size, sensor_size)`` per-pixel weights.
    """
    side = int(math.sqrt(num_lens))
    axis = torch.arange(side, device=device)
    rows, cols = torch.meshgrid(axis, axis, indexing="ij")
    view_order = (side - 1 - rows + (side - 1 - cols) * side).reshape(-1)

    stacked = torch.stack(images, dim=0).to(device)
    patches = F.interpolate(
        stacked[view_order],
        size=(dim_lens_lf_yx[0], dim_lens_lf_yx[1]),
        mode="bilinear",
        align_corners=False,
    )
    lens_h, lens_w = dim_lens_lf_yx[0], dim_lens_lf_yx[1]

    canvas = torch.zeros(3, sensor_size, sensor_size, device=device, dtype=torch.float32)
    for lens in range(num_lens):
        ys = comap_yx[lens, :, :, 0]
        xs = comap_yx[lens, :, :, 1]
        inside = (
            (ys != -1)
            & (xs != -1)
            & (ys >= 0)
            & (ys < lens_h)
            & (xs >= 0)
            & (xs < lens_w)
        )
        if not inside.any():
            continue  # this lens touches no sensor pixel
        dst_y, dst_x = torch.where(inside)
        src_y = ys[inside].long()
        src_x = xs[inside].long()
        weights = alpha_map[dst_y, dst_x].unsqueeze(0)
        canvas[:, dst_y, dst_x] += patches[lens, :, src_y, src_x] * weights

    # Keep pixel values in the displayable range [0, 1].
    return canvas.clamp(0, 1)


# def generate(images, comap_yx, dim_lens_lf_yx, num_lens, sensor_size, alpha_map):
#     grid_size = int(math.sqrt(num_lens))
#     idx = torch.arange(grid_size, device=device)
#     grid_i, grid_j = torch.meshgrid(idx, idx, indexing='ij')
#     mapping = ((grid_size - 1 - grid_i) + (grid_size - 1 - grid_j) * grid_size).reshape(-1)

#     images_tensor = torch.stack(images, dim=0).to(device)
#     selected_images = images_tensor[mapping]
#     resized_images = F.interpolate(
#         selected_images,
#         size=(dim_lens_lf_yx[0], dim_lens_lf_yx[1]),
#         mode='bilinear',
#         align_corners=False
#     )

#     output_image = torch.zeros(3, sensor_size, sensor_size, device=device, dtype=torch.float32)

#     for i in range(num_lens):
#         y_coords = comap_yx[i, :, :, 0]
#         x_coords = comap_yx[i, :, :, 1]

#         valid_mask = (y_coords != -1) & (x_coords != -1) & \
#                      (y_coords >= 0) & (y_coords < dim_lens_lf_yx[0]) & \
#                      (x_coords >= 0) & (x_coords < dim_lens_lf_yx[1])

#         # Only process this microlens if there are any valid mapping positions.
#         if valid_mask.any():
#             # Get 2D indices within the sub-image where valid_mask is True.
#             y_indices, x_indices = torch.where(valid_mask)
#             y_src = y_coords[valid_mask].long()
#             x_src = x_coords[valid_mask].long()
#             output_image[:, y_indices, x_indices] += resized_images[i, :, y_src, x_src] * alpha_map[y_indices, x_indices].unsqueeze(0)

#     # Clamp the final output to ensure pixel values are in the valid range [0, 1].
#     output_image = torch.clamp(output_image, 0, 1)

#     return output_image

# Directory holding the per-sublens renders r_{base}_{j}.png. This was previously
# misspelled `mode_path`, so read_images silently used a stale `model_path` left
# over from an earlier cell.
model_path = "/home/wl757/multiplexed-pixels/plenoxels/blender_data/lego_gen12/train_multilens_16_black"
base = "59"
NUM_LENS = 16
SENSOR_SIZE = 800
d_lens_sensor = 20  # this is the value to change for more or less multiplexing

comap_yx, dim_lens_lf_yx = get_comap(NUM_LENS, d_lens_sensor, SENSOR_SIZE, SENSOR_SIZE)
alpha_map = (
    torch.from_numpy(generate_alpha_map(comap_yx, NUM_LENS, SENSOR_SIZE, SENSOR_SIZE))
    .float()
    .to(device)
)
plt.imshow(alpha_map.cpu().numpy())
plt.colorbar()
plt.show()
comap_yx = torch.from_numpy(comap_yx).to(device)
images = read_images(NUM_LENS, model_path, base)
# multiplexed_image = generate(images, comap_yx, dim_lens_lf_yx, NUM_LENS, SENSOR_SIZE, alpha_map)
sub_images = generate_sub_images(
    images, comap_yx, dim_lens_lf_yx, NUM_LENS, SENSOR_SIZE
)
for rgb in sub_images:
    # Sub-images are RGB only and zero outside their lens footprint, so weighting
    # by the global alpha_map shows each lens's contribution to the blended image.
    plt.imshow((rgb * alpha_map).cpu().numpy().transpose(1, 2, 0))
    plt.show()
# plt.imshow(sub_images[5].cpu().numpy().transpose(1, 2, 0))
# plt.show()
# output_image = torch.zeros(3, SENSOR_SIZE, SENSOR_SIZE, device=device, dtype=torch.float32)
# for sub_image in sub_images[5:7]:
#     rgb = sub_image[:3]
#     alpha = sub_image[3]
#     output_image += rgb * alpha
# output_image = torch.clamp(output_image, 0, 1)
# plt.imshow(output_image.cpu().numpy().transpose(1, 2, 0))
# plt.show()

Calculate the maximum number of microlens rays that map to a single sensor pixel.

Generate multiplexed images. Set `dir_name` to the directory where the multiplexed images will be saved, set `selected_views` to a list of view indices from transform_train.json, and set `rendered_views_path` to the directory that contains the images rendered for each sublens.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

# NOTE(review): get_comap (defined below) derives its own square grid from its
# num_lens argument and rejects non-square counts, so non-square layouts such as
# [2, 1] would not work with it.
num_lenses_yx = [4, 4]  # alternatives tried: [10, 10], [2, 1], [1, 1]
# NOTE(review): MAX_PER_PIXEL is not read by any code visible in this notebook.
MAX_PER_PIXEL = 20
NUM_LENS = num_lenses_yx[0] * num_lenses_yx[1]
sensor_size = 800  # sensor height and width in pixels
d_lens_sensor = 20  # this is the value to change for more or less multiplexing


def get_comap(num_lens, d_lens_sensor, H, W):
    """Build per-microlens coordinate maps for a square grid of microlenses.

    Each microlens projects a square patch onto an ``H x W`` sensor. The patch side
    is ``(min(H // n, W // n) // 12) * d_lens_sensor`` pixels, rounded down to an even
    number, where ``n = sqrt(num_lens)``. Patch centres are spaced with
    ``np.linspace`` so the outermost patches touch the sensor edges; neighbouring
    patches overlap when the patch side exceeds the centre spacing.

    Args:
        num_lens: Total number of microlenses; must be a positive perfect square.
        d_lens_sensor: Multiplier for the patch side (larger -> more overlap).
        H: Sensor height in pixels.
        W: Sensor width in pixels.

    Returns:
        ``(comap_yx, [microlens_height, microlens_width])``. ``comap_yx`` has shape
        ``(num_lens, H, W, 2)``; ``comap_yx[i, y, x]`` is the local ``(row, col)``
        inside lens ``i``'s patch that lands on sensor pixel ``(y, x)``, or
        ``(-1, -1)`` where lens ``i`` does not cover that pixel.

    Raises:
        ValueError: If ``num_lens`` is not a positive perfect square.
    """
    # Exact integer check; raise instead of `assert` (stripped under `python -O`).
    if num_lens < 1 or math.isqrt(num_lens) ** 2 != num_lens:
        raise ValueError("Number of sublens should be a square number")
    side = math.isqrt(num_lens)

    # Patch side in pixels; forced even so it splits evenly about its centre.
    base_microlens_size = min(H // side, W // side) // 12
    microlens_height = int(base_microlens_size * d_lens_sensor)
    microlens_height -= microlens_height % 2
    microlens_width = microlens_height  # microlenses are square
    half_h, half_w = microlens_height // 2, microlens_width // 2

    comap_yx = -np.ones((num_lens, H, W, 2))

    # Spread centres edge to edge; a single row/column is centred instead.
    if side > 1:
        y_positions = np.linspace(half_h, H - half_h, side)
        x_positions = np.linspace(half_w, W - half_w, side)
    else:
        y_positions = np.array([H // 2])
        x_positions = np.array([W // 2])

    for i in range(num_lens):
        row, col = divmod(i, side)
        center_y, center_x = int(y_positions[row]), int(x_positions[col])
        start_y = max(0, center_y - half_h)
        end_y = min(H, center_y + half_h)
        start_x = max(0, center_x - half_w)
        end_x = min(W, center_x + half_w)
        if end_y <= start_y or end_x <= start_x:
            continue  # patch falls entirely off the sensor

        # Vectorised fill: local coordinates count from the (clipped) patch origin.
        comap_yx[i, start_y:end_y, start_x:end_x, 0] = np.arange(end_y - start_y)[:, None]
        comap_yx[i, start_y:end_y, start_x:end_x, 1] = np.arange(end_x - start_x)[None, :]

    dim_lens_lf_yx = [microlens_height, microlens_width]
    return comap_yx, dim_lens_lf_yx


# Visualise the coordinate maps for this cell's configuration. Use the constants
# defined at the top of the cell rather than repeating their values as literals.
comap_yx, _ = get_comap(NUM_LENS, d_lens_sensor, sensor_size, sensor_size)
print(comap_yx.max())
# Local row coordinate for lens 0 (grid row 0, column 0); -1 where it does not cover.
plt.imshow(comap_yx[0, :, :, 0])
plt.colorbar()
plt.show()
# Local column coordinate for lens 3 (grid row 0, column 3 in a 4x4 grid).
plt.imshow(comap_yx[3, :, :, 1])
plt.colorbar()
plt.show()