In [None]:
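# Todo
# In the data loader: merge all rays into a single array, then write code that re-splits it by batch size
# Rework ray_o and ray_d again in the ray-processing step
# When loading data, read it with np.load and display it with plt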


# Code Initiation

In [None]:
!pip install kornia
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision import transforms as T
from google.colab import drive
from google.colab.patches import cv2_imshow
import cv2
import json
import imageio
from tqdm import tqdm
import torch.nn as nn
from PIL import Image
# Ignore warnings
import warnings
from kornia import create_meshgrid
from datetime import datetime

Collecting kornia
  Downloading kornia-0.7.2-py2.py3-none-any.whl (825 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m825.4/825.4 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting kornia-rs>=0.1.0 (from kornia)
  Downloading kornia_rs-0.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.9.1->kornia)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.9.1->kornia)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.9.1->kornia)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.9

# Helper Functions

In [None]:
def create_folder(folder_path):
    """Ensure ``folder_path`` exists and print what was done.

    If the path is missing, it is created together with any missing parent
    directories. If it already exists, nothing on disk is changed. A one-line
    status message is printed in both cases.
    """
    if os.path.exists(folder_path):
        print(f"Folder already exists: {folder_path}")
        return
    os.makedirs(folder_path)
    print(f"Folder created: {folder_path}")

def save_model(model, optimizer, path, name, lr, nb_epochs):
    """Save the state dicts of ``model`` and ``optimizer`` under ``<path>models/``.

    Args:
        model: module whose ``state_dict()`` is written to
            ``nerf_model_<name>_lr<lr>_epochs<nb_epochs>.pth``.
        optimizer: optimizer whose ``state_dict()`` is written to
            ``nerf_optimizer_<name>_lr<lr>_epochs<nb_epochs>.pth``.
        path (str): base directory prefix. ``"models/"`` is appended by plain
            string concatenation, so it should end with a path separator.
        name, lr, nb_epochs: values embedded in the two file names.
    """
    path = path + "models/"
    create_folder(path)
    suffix = f'{name}_lr{lr}_epochs{nb_epochs}.pth'
    model_save_path = f'{path}nerf_model_{suffix}'
    optimizer_save_path = f'{path}nerf_optimizer_{suffix}'
    torch.save(model.state_dict(), model_save_path)
    # Save the optimizer that was passed in. The previous code ignored the
    # argument and read a notebook-level global instead.
    torch.save(optimizer.state_dict(), optimizer_save_path)
    print("NeRF Model Saved Successfully")

def load_model(model_path, opt_path, nerf_model=None, optimizer=None, map_location=None):
    """Restore model and optimizer state dicts from disk, in place.

    Args:
        model_path: file holding a model ``state_dict``.
        opt_path: file holding an optimizer ``state_dict``.
        nerf_model: module to load into. When omitted, the notebook-level
            global ``model`` is used (the original behaviour).
        optimizer: optimizer to load into. When omitted, the notebook-level
            global ``model_optimizer`` is used (the original behaviour).
        map_location: forwarded to ``torch.load`` so a checkpoint saved on one
            device can be restored on another. ``None`` keeps torch's default.

    Returns:
        None. The target objects are updated in place.
    """
    # The global fallbacks are only evaluated when no explicit target is given.
    target_model = model if nerf_model is None else nerf_model
    target_optimizer = model_optimizer if optimizer is None else optimizer
    target_model.load_state_dict(torch.load(model_path, map_location=map_location))
    target_optimizer.load_state_dict(torch.load(opt_path, map_location=map_location))
    return



# Data Loader

In [None]:
class ChairDataset(Dataset):
    """Pre-batched ray dataset built from a NeRF-style annotation JSON file.

    Every frame listed in the annotation file is converted to per-pixel rays
    by ``read_data``. The rays of all frames are concatenated into flat
    tensors, and item ``idx`` is the ``idx``-th consecutive slice of
    ``batch_size`` rays (the last slice may be shorter).
    """

    def __init__(self, datadir, json_dir = 'transforms_train.json', img_dir = 'train/', batch_size=256, H=400, W=400):
        """
        Arguments:
            datadir (string): Directory prefix. It is concatenated as-is with
                ``json_dir`` and with each frame's ``file_path``, so it should
                end with a path separator.
            json_dir (string): Name of the annotation JSON file inside
                ``datadir``. It must provide ``camera_angle_x`` and a
                ``frames`` list whose entries have ``file_path`` and
                ``transform_matrix``.
            img_dir (string): Currently unused; kept so existing calls that
                pass it keep working.
            batch_size (int): Number of rays returned per item.
            H (int): Image height passed to ``read_data``.
            W (int): Image width passed to ``read_data``.
        """
        # Context manager so the annotation file is closed even if parsing fails.
        with open(datadir + json_dir, "r") as f:
            data_json = json.load(f)
        camera_angle_x = data_json['camera_angle_x']

        rays_origin = []
        rays_direction = []
        pixel_value = []

        for frame in data_json['frames']:
            # Read the fields by name. Relying on the key order of each frame
            # dict breaks as soon as a frame carries an extra key or lists its
            # keys in a different order.
            img_name = datadir + frame['file_path'].replace('\\', '/')
            trans_mat = frame['transform_matrix']

            rays_o_, rays_d_, target_px_values_ = read_data(img_name, trans_mat, camera_angle_x, H, W)

            rays_origin.append(rays_o_)
            rays_direction.append(rays_d_)
            pixel_value.append(target_px_values_)

        rays_o = torch.cat(rays_origin)
        rays_d = torch.cat(rays_direction)
        target_px_values = torch.cat(pixel_value)
        print(len(rays_o))
        print(len(target_px_values))

        # Chair Dataset Variables
        self.size = rays_o.shape[0]      # total number of rays over all frames
        self.H = H
        self.W = W
        self.batch_size = batch_size
        self.rays_o = rays_o
        self.rays_d = rays_d
        self.target_px_values = target_px_values

    def __len__(self):
        """Number of batches (ceiling of ``size / batch_size``)."""
        return (self.size + self.batch_size - 1) // self.batch_size

    def __H__(self):
        """Return the image height given at construction."""
        return self.H

    def __W__(self):
        """Return the image width given at construction."""
        return self.W

    def __getitem__(self, idx):
        """Return batch ``idx`` as a dict of ray tensors.

        The dict has keys ``'rays_o'``, ``'rays_d'`` and
        ``'target_px_values'``, each holding up to ``batch_size`` consecutive
        rows. Negative indices count from the last batch.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                This is what lets a plain ``for batch in dataset`` loop stop
                after the last batch instead of yielding empty slices forever.
        """
        n_batches = len(self)
        if idx < 0:
            idx += n_batches
        if idx < 0 or idx >= n_batches:
            raise IndexError(f"batch index out of range for {n_batches} batches")

        start = idx * self.batch_size
        end = min(start + self.batch_size, self.size)

        sample = {'rays_o': self.rays_o[start:end],
                  'rays_d': self.rays_d[start:end],
                  'target_px_values': self.target_px_values[start:end]}

        return sample


In [None]:
def read_data(image_dir, pose, camera_angle_x, H, W):
    """Load one frame and return its per-pixel rays and target colours.

    Args:
        image_dir: path of the image file to open.
        pose: nested sequence/array holding the camera-to-world matrix; only
            its top-left 3x4 part is used.
        camera_angle_x: horizontal field of view in radians.
        H, W: size the image is resized to and for which rays are generated.

    Returns:
        tuple: ``(rays_o, rays_d, img)``. ``rays_o`` and ``rays_d`` come from
        ``get_rays`` flattened to ``(H*W, 3)``. ``img`` is an ``(H*W, 3)``
        float tensor of RGB values in ``[0, 1]``; RGBA inputs are blended
        onto a white background.
    """
    # Pinhole focal length in pixels for an image W pixels wide. This is the
    # 800-px focal length rescaled by W/800, written in one step.
    focal = 0.5 * W / np.tan(0.5 * camera_angle_x)

    # Camera-space ray directions, identical for every frame of this size.
    directions = get_ray_directions(H, W, focal)  # (h, w, 3)

    c2w = torch.FloatTensor(np.array(pose)[:3, :4])

    # resize() returns an independent, fully loaded image, so the file handle
    # can be released immediately.
    with Image.open(image_dir) as raw_img:
        resized = raw_img.resize((W, H), Image.LANCZOS)
    if resized.mode != 'RGBA':
        # Grayscale / palette / other modes are normalised to 3 channels.
        resized = resized.convert('RGB')

    img = T.ToTensor()(resized)  # (C, h, w) with C in {3, 4}
    # Branch on the tensor's channel dimension. np.shape() of a PIL image is
    # (h, w, C), so testing its first entry compared the *height* with 4.
    n_channels = img.shape[0]
    img = img.view(n_channels, -1).permute(1, 0)  # (h*w, C)
    if n_channels == 4:
        img = img[:, :3]*img[:, -1:] + (1-img[:, -1:])  # blend A to RGB (white bg)

    rays_o, rays_d = get_rays(directions, c2w)

    return rays_o, rays_d, img



def get_ray_directions(H, W, focal):
    """Build the camera-space direction of the ray through every pixel.

    The camera looks along -z, x grows with the first grid coordinate and y
    is flipped relative to the second one. Directions are not normalised.

    Args:
        H, W: image height and width.
        focal: focal length in pixels.

    Returns:
        Tensor of ray directions, one 3-vector per pixel -- (H, W, 3),
        assuming ``create_meshgrid`` yields a (1, H, W, 2) grid of (x, y)
        pixel coordinates (TODO confirm against the installed kornia version).
    """
    pixel_grid = create_meshgrid(H, W, normalized_coordinates=False)[0]
    cols, rows = pixel_grid.unbind(-1)

    x_cam = (cols - W/2) / focal
    y_cam = -(rows - H/2) / focal
    z_cam = -torch.ones_like(cols)

    return torch.stack([x_cam, y_cam, z_cam], -1)


def get_rays(directions, c2w):
    """Turn camera-space ray directions into world-space rays.

    Args:
        directions: (..., 3) tensor of camera-space ray directions.
        c2w: (3, 4) camera-to-world matrix; the first three columns are the
            rotation and the last column is the camera position.

    Returns:
        tuple: ``(rays_o, rays_d)``, both flattened to (N, 3). ``rays_d`` is
        unit length and ``rays_o`` repeats the camera position for every ray.
    """
    rotation = c2w[:, :3]
    camera_position = c2w[:, 3]

    # Row-vector form of R @ d, followed by normalisation to unit length.
    world_dirs = torch.matmul(directions, rotation.T)
    world_dirs = world_dirs / torch.norm(world_dirs, dim=-1, keepdim=True)

    # Every ray of this camera starts at the same point.
    world_origins = camera_position.expand(world_dirs.shape)

    flat_dirs = world_dirs.view(-1, 3)
    flat_origins = world_origins.view(-1, 3)
    return flat_origins, flat_dirs

# Ray Process

In [None]:
def compute_accumulated_transmittance(alphas):
    """Exclusive cumulative product along dim 1.

    Column ``k`` of the result is the product of ``alphas[:, :k]``, so the
    first column is all ones and the last input column never contributes.

    Args:
        alphas: 2-D tensor of per-sample factors, one row per ray.

    Returns:
        Tensor with the same shape as ``alphas``.
    """
    running_product = torch.cumprod(alphas, 1)
    leading_ones = torch.ones((running_product.shape[0], 1), device=alphas.device)
    shifted = running_product[:, :-1]
    return torch.cat((leading_ones, shifted), dim=-1)


def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    """Volume-render one colour per ray using jittered depth samples.

    Args:
        nerf_model: callable taking flattened points and directions, each
            (N, 3), and returning ``(colors, sigma)``.
        ray_origins: [batch_size, 3] ray start points.
        ray_directions: [batch_size, 3] ray directions.
        hn, hf: near and far bounds of the sampling interval along each ray.
        nb_bins: number of samples per ray.

    Returns:
        [batch_size, 3] rendered colours. The part of each ray not absorbed by
        any sample is filled with white (``1 - sum(weights)``).
    """
    device = ray_origins.device
    n_rays = ray_origins.shape[0]

    # Evenly spaced depths shared by all rays ...
    t = torch.linspace(hn, hf, nb_bins, device=device).expand(n_rays, nb_bins)
    # ... then one uniformly random depth inside each bin around them.
    midpoints = (t[:, :-1] + t[:, 1:]) / 2.
    bin_lo = torch.cat((t[:, :1], midpoints), -1)
    bin_hi = torch.cat((midpoints, t[:, -1:]), -1)
    jitter = torch.rand(t.shape, device=device)
    t = bin_lo + (bin_hi - bin_lo) * jitter  # [batch_size, nb_bins]

    # Distance to the next sample; the last sample gets a very large interval.
    far_pad = torch.tensor([1e10], device=device).expand(n_rays, 1)
    delta = torch.cat((t[:, 1:] - t[:, :-1], far_pad), -1)

    # 3D sample positions and the matching per-sample view direction.
    points = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)   # [batch_size, nb_bins, 3]
    view_dirs = ray_directions.unsqueeze(1).expand(-1, nb_bins, 3)

    colors, sigma = nerf_model(points.reshape(-1, 3), view_dirs.reshape(-1, 3))
    colors = colors.reshape(points.shape)
    sigma = sigma.reshape(points.shape[:-1])

    alpha = 1 - torch.exp(-sigma * delta)  # [batch_size, nb_bins]
    transmittance = compute_accumulated_transmittance(1 - alpha)
    weights = (transmittance * alpha).unsqueeze(2)

    # Weighted sum of colours along each ray, plus the white-background term.
    rendered = (weights * colors).sum(dim=1)
    opacity = weights.sum(-1).sum(-1)
    return rendered + 1 - opacity.unsqueeze(-1)

# NeRF Model

In [None]:
class NerfModel(nn.Module):
    """MLP mapping encoded positions and view directions to colour and density.

    Positions go through ``block1``; its output is concatenated with the
    encoded position again and fed to ``block2``, whose last output channel is
    the density. The remaining features plus the encoded direction go through
    ``block3`` and ``block4`` to produce a sigmoid-bounded RGB colour.
    """

    def __init__(self, embedding_dim_pos=10, embedding_dim_direction=4, hidden_dim=128):
        super(NerfModel, self).__init__()

        # Encoded width: the raw 3 coordinates plus sin and cos per frequency.
        pos_width = embedding_dim_pos * 6 + 3
        dir_width = embedding_dim_direction * 6 + 3

        self.block1 = nn.Sequential(
            nn.Linear(pos_width, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        # density estimation (skip connection re-injects the encoded position)
        self.block2 = nn.Sequential(
            nn.Linear(pos_width + hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim + 1),
        )
        # color estimation
        self.block3 = nn.Sequential(
            nn.Linear(dir_width + hidden_dim, hidden_dim // 2), nn.ReLU(),
        )
        self.block4 = nn.Sequential(
            nn.Linear(hidden_dim // 2, 3), nn.Sigmoid(),
        )

        self.embedding_dim_pos = embedding_dim_pos
        self.embedding_dim_direction = embedding_dim_direction
        self.relu = nn.ReLU()

    @staticmethod
    def positional_encoding(x, L):
        """Concatenate ``x`` with ``sin(2^j x)`` and ``cos(2^j x)`` for j < L.

        For ``x`` of shape [N, D] the result has shape [N, D * (2 * L + 1)].
        """
        features = [x]
        for freq in (2 ** j for j in range(L)):
            features += [torch.sin(freq * x), torch.cos(freq * x)]
        return torch.cat(features, dim=1)

    def forward(self, o, d):
        """Return ``(colour [N, 3], density [N])`` for points ``o`` and directions ``d``."""
        pos_feat = self.positional_encoding(o, self.embedding_dim_pos)        # [N, embedding_dim_pos * 6 + 3]
        dir_feat = self.positional_encoding(d, self.embedding_dim_direction)  # [N, embedding_dim_direction * 6 + 3]

        trunk = self.block1(pos_feat)                                  # [N, hidden_dim]
        fused = self.block2(torch.cat((trunk, pos_feat), dim=1))       # [N, hidden_dim + 1]

        # Last channel is the density (clamped to be non-negative).
        sigma = self.relu(fused[:, -1])                                # [N]
        features = fused[:, :-1]                                       # [N, hidden_dim]

        colour_hidden = self.block3(torch.cat((features, dir_feat), dim=1))  # [N, hidden_dim // 2]
        rgb = self.block4(colour_hidden)                               # [N, 3]
        return rgb, sigma

def init_weights(m):
    """Initialise ``nn.Linear`` layers in place; leave other modules untouched.

    Weights get Xavier-uniform values and a bias, if present, is zeroed.
    Intended to be passed to ``Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    if m.bias is None:
        return
    torch.nn.init.zeros_(m.bias)

# NeRF Test


In [None]:
@torch.no_grad()
def test(hn, hf, dataset, chunk_size=10, nb_bins=192, H=400, W=400, epoch_idx = 0,
         output="/content/drive/MyDrive/NeRF_Data_Repository/output", lr=5e-4,
         save_every=25, nerf_model=None):
    """Render every view of ``dataset`` and save every ``save_every``-th one as a PNG.

    Args:
        hn: near plane distance
        hf: far plane distance
        dataset: iterable with ``len()`` whose items are dicts holding
            ``'rays_o'`` and ``'rays_d'`` for one full ``H x W`` view.
        chunk_size (int, optional): number of image rows rendered per forward
            pass, for memory efficiency. Defaults to 10.
        nb_bins (int, optional): number of samples per ray. Defaults to 192.
        H (int, optional): image height. Defaults to 400.
        W (int, optional): image width. Defaults to 400.
        epoch_idx (int, optional): epoch number embedded in the file name.
        output (str, optional): directory the PNG files are written to; it is
            created if missing.
        lr (float, optional): learning rate embedded in the file name.
        save_every (int, optional): positive stride; views whose index is a
            multiple of it are saved. Defaults to 25.
        nerf_model (optional): model used for rendering. When omitted, the
            notebook-level global ``model`` is used (the original behaviour).

    Returns:
        None: None
    """
    net = model if nerf_model is None else nerf_model
    # Render on the device the model lives on instead of reading a global.
    net_device = next(net.parameters()).device
    os.makedirs(output, exist_ok=True)

    n_views = len(dataset)
    rays_per_chunk = W * chunk_size
    n_chunks = int(np.ceil(H / chunk_size))

    # NOTE(review): every view is rendered (and range-checked) even though
    # only every ``save_every``-th one is written to disk.
    for idx, batch in enumerate(dataset):
        # Guard against iterables that keep yielding items past len(dataset).
        if idx >= n_views:
            break

        ray_origins = batch['rays_o']
        ray_directions = batch['rays_d']

        data = []   # list of regenerated pixel values, one entry per chunk
        for i in range(n_chunks):   # iterate over chunks
            rows = slice(i * rays_per_chunk, (i + 1) * rays_per_chunk)
            ray_origins_ = ray_origins[rows].to(net_device)
            ray_directions_ = ray_directions[rows].to(net_device)
            regenerated_px_values = render_rays(net, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
            if torch.any(regenerated_px_values > 1):
                print("Test Not Normalized")
            data.append(regenerated_px_values)
        # Gradients are disabled here, so no detach is needed before numpy().
        img = torch.cat(data).cpu().numpy().reshape(H, W, 3)

        if np.any(img > 1):
            print("error")

        if idx % save_every == 0:
            fig = plt.figure()
            plt.title("Test")
            plt.imshow(img)
            file_name = f'image_{idx}_epoch_{epoch_idx}_lr_{lr}.png'
            save_path = f'{output}/{file_name}'
            plt.savefig(save_path, bbox_inches='tight')
            plt.close(fig)

# NeRF Train

In [None]:
def train(nerf_model, optimizer, scheduler, data_loader, test_loader, device='cuda', hn=0, hf=1, nb_epochs=int(1e5),
          nb_bins=192, H=400, W=400, directory="/content/drive/MyDrive/NeRF_Data_Repository/", lr=5e-4):
    """Optimise ``nerf_model`` and render the test views after every epoch.

    Args:
        nerf_model: model passed to ``render_rays``.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler stepped once per epoch.
        data_loader: iterable with ``len()`` whose items are dicts holding
            ``'rays_o'``, ``'rays_d'`` and ``'target_px_values'``.
        test_loader: dataset handed to ``test`` after each epoch.
        device: device the batches are moved to.
        hn, hf: near and far bounds of the sampling interval along each ray.
        nb_epochs: number of passes over ``data_loader``.
        nb_bins: number of samples per ray.
        H, W: image size forwarded to ``test``.
        directory: prefix to which ``"output"`` is appended (plain string
            concatenation) to form the folder for rendered test images.
        lr: learning rate, only used in the rendered images' file names.

    Returns:
        list: one detached scalar tensor per epoch, the mean over that epoch's
        batches of the summed squared pixel error.
    """
    training_loss = []
    output_dir = directory + "output"
    create_folder(output_dir)
    n_batches = len(data_loader)

    for epoch_idx in tqdm(range(nb_epochs)):
        epoch_loss = []
        for i, batch in enumerate(data_loader):
            # Stop *before* processing anything past the last real batch. The
            # old check ran after the optimisation step, so each epoch ended
            # with an extra empty batch whose loss is exactly 0.
            if i >= n_batches:
                break

            ray_origins = batch['rays_o'].to(device)
            ray_directions = batch['rays_d'].to(device)
            ground_truth_px_values = batch['target_px_values'].to(device)
            if torch.any(ground_truth_px_values > 1):
                print("Input Not Normalized")
            regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)

            loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Detach so the stored history does not keep autograd nodes alive.
            epoch_loss.append(loss.detach())

        scheduler.step()
        if epoch_loss:
            training_loss.append(torch.stack(epoch_loss).mean())

        test(hn, hf, test_loader, nb_bins=nb_bins, H=H, W=W, epoch_idx = epoch_idx, output=output_dir, lr = lr)
    return training_loss

# Variables

In [None]:
# Run configuration read by the dataset-construction and training cells below.
# batch_size, nb_bins: Increase for precision, but takes longer


# Image size every frame is resized to (passed as H / W).
height = 200
width = 200
# Number of rays per training batch.
batch_size = 200
# Width of the hidden layers of NerfModel.
hidden_dim = 256
# Number of samples taken along each ray.
nb_bins = 200
nb_epochs = 20
# Learning rate given to the Adam optimizer.
learning_rate = 3e-5
# Near / far bounds of the sampling interval along each ray (passed as hn / hf).
# NOTE(review): presumably a scene centred ~20 units from the camera with a
# +/-16 margin -- confirm against the dataset.
near_plane = 20 - 16
far_plane = 20 + 16
# Scene folder name; used to build data_dir and the train/test sub-folder names.
folder_name = "tomato"
# Other folder names left here by the author (presumably alternative scenes):
#whiteflower
#tomato_tree

In [None]:
# Mount Google Drive into the Colab filesystem (requires google.colab).
drive.mount('/content/drive')
# Root folder of the selected scene. The trailing "/" matters because later
# paths are built from this value by plain string concatenation.
data_dir = "/content/drive/MyDrive/NeRF_Data_Repository/" + folder_name + "/"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Sub-folders of data_dir that hold the two splits.
train_dir = folder_name + "_dataset_train/"
test_dir = folder_name + "_dataset_test/"

# Despite the "_loader" names these are ChairDataset instances, not DataLoaders.
# Training items are small ray batches.
train_loader = ChairDataset(
    datadir=data_dir + train_dir,
    img_dir="train/",
    batch_size=batch_size,
    H=height,
    W=width,
)
# Test items hold height*width rays, i.e. one whole image each.
# NOTE(review): json_dir is left at its default 'transforms_train.json' for the
# test split as well -- confirm that file name is intended there.
test_loader = ChairDataset(
    datadir=data_dir + test_dir,
    img_dir="train/",
    batch_size=height*width,
    H=height,
    W=width,
)

4000000
4000000
4000000
4000000


# Main

In [None]:
device = 'cuda'

# Build the network on the target device and re-initialise its linear layers.
model = NerfModel(hidden_dim=hidden_dim).to(device)
model.apply(init_weights)

# Adam with the learning rate halved after epochs 2, 4 and 8.
model_optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    model_optimizer,
    milestones=[2, 4, 8],
    gamma=0.5,
)

# Train; test views are rendered into <data_dir>output after every epoch.
losses = train(
    model,
    model_optimizer,
    scheduler,
    train_loader,
    test_loader,
    device=device,
    hn=near_plane,
    hf=far_plane,
    nb_epochs=nb_epochs,
    nb_bins=nb_bins,
    H=height,
    W=width,
    directory=data_dir,
    lr=learning_rate,
)
print(losses)
#plt.plot(np.arange(nb_epochs), losses)

Folder already exists: /content/drive/MyDrive/NeRF_Data_Repository/tomato/output


100%|██████████| 20/20 [2:00:20<00:00, 361.05s/it]


[tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackward0>), tensor(0., device='cuda:0', grad_fn=<SumBackw

# Trash

In [None]:
# def read_data(image_dir, pose, camera_angle_x, H, W):
#     # with open(os.path.join(self.root_dir,
#     #                         f"transforms_{self.split}.json"), 'r') as f:
#     #     self.meta = json.load(f)


#     focal = 0.5*800/np.tan(0.5*camera_angle_x) # original focal length
#                                                                   # when W=800

#     focal *= W/800 # modify focal length to match size self.img_wh

#     # bounds, common for all scenes
#     near = 2.0
#     far = 6.0
#     bounds = np.array([near, far])

#     # ray directions for all pixels, same for all images (same H, W, focal)
#     directions = \
#         get_ray_directions(H, W, focal) # (h, w, 3)

#     pose = np.array(frame['transform_matrix'])[:3, :4]
#     poses += [pose]
#     c2w = torch.FloatTensor(pose)


#     img = Image.open(image_dir)
#     img = img.resize((W, H), Image.LANCZOS)
#     img = transform(img) # (4, h, w)
#     img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA
#     img = img[:, :3]*img[:, -1:] + (1-img[:, -1:]) # blend A to RGB
#     all_rgbs += [img]

#     rays_o, rays_d = get_rays(directions, c2w)


#     if split == 'train': # create buffer of all rays and rgb data
#         image_paths = []
#         poses = []
#         all_rays = []
#         all_rgbs = []
#         for frame in meta['frames']:
#             pose = np.array(frame['transform_matrix'])[:3, :4]
#             poses += [pose]
#             c2w = torch.FloatTensor(pose)

#             image_path = os.path.join(root_dir, f"{frame['file_path']}.png")
#             image_paths += [image_path]
#             img = Image.open(image_path)
#             img = img.resize(img_wh, Image.LANCZOS)
#             img = transform(img) # (4, h, w)
#             img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA
#             img = img[:, :3]*img[:, -1:] + (1-img[:, -1:]) # blend A to RGB
#             all_rgbs += [img]

#             rays_o, rays_d = get_rays(directions, c2w) # both (h*w, 3)

#             all_rays += [torch.cat([rays_o, rays_d,
#                                           near*torch.ones_like(rays_o[:, :1]),
#                                           far*torch.ones_like(rays_o[:, :1])],
#                                           1)] # (h*w, 8)

#         all_rays = torch.cat(all_rays, 0) # (len(self.meta['frames])*h*w, 3)
#         all_rgbs = torch.cat(all_rgbs, 0) # (len(self.meta['frames])*h*w, 3)


# def get_ray_directions(H, W, focal):
#     """
#     Get ray directions for all pixels in camera coordinate.
#     Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
#                ray-tracing-generating-camera-rays/standard-coordinate-systems

#     Inputs:
#         H, W, focal: image height, width and focal length

#     Outputs:
#         directions: (H, W, 3), the direction of the rays in camera coordinate
#     """
#     grid = create_meshgrid(H, W, normalized_coordinates=False)[0]
#     i, j = grid.unbind(-1)
#     # the direction here is without +0.5 pixel centering as calibration is not so accurate
#     # see https://github.com/bmild/nerf/issues/24
#     directions = \
#         torch.stack([(i-W/2)/focal, -(j-H/2)/focal, -torch.ones_like(i)], -1) # (H, W, 3)

#     return directions


# def get_rays(directions, c2w):
#     """
#     Get ray origin and normalized directions in world coordinate for all pixels in one image.
#     Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
#                ray-tracing-generating-camera-rays/standard-coordinate-systems

#     Inputs:
#         directions: (H, W, 3) precomputed ray directions in camera coordinate
#         c2w: (3, 4) transformation matrix from camera coordinate to world coordinate

#     Outputs:
#         rays_o: (H*W, 3), the origin of the rays in world coordinate
#         rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
#     """
#     # Rotate ray directions from camera coordinate to the world coordinate
#     rays_d = directions @ c2w[:, :3].T # (H, W, 3)
#     rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
#     # The origin of all rays is the camera origin in world coordinate
#     rays_o = c2w[:, 3].expand(rays_d.shape) # (H, W, 3)

#     rays_d = rays_d.view(-1, 3)
#     rays_o = rays_o.view(-1, 3)

#     return rays_o, rays_d

In [None]:
# from PIL import Image

# def generate_ray_directions(K_inv, width, height):
#     # Meshgrid for pixel coordinates
#     x = np.linspace(0, width - 1, width)
#     y = np.linspace(0, height - 1, height)
#     x, y = np.meshgrid(x, y)
#     # Homogeneous coordinates of pixels
#     pixels = np.stack([x.flatten(), y.flatten(), np.ones_like(x.flatten())], axis=-1)
#     # Transform to camera space
#     ray_dirs_camera = K_inv @ pixels.T
#     ray_dirs_camera = ray_dirs_camera[:3, :].T  # Remove homogeneous coordinate
#     # Normalize directions
#     norms = np.linalg.norm(ray_dirs_camera, axis=1, keepdims=True)
#     ray_dirs_camera /= norms
#     return ray_dirs_camera


# def load_ground_truth_image(image_path):
#     # Load an image file as ground truth
#     with Image.open(image_path) as img:
#         white_background = Image.new("RGB", img.size, (255, 255, 255))
#         rgb_image = Image.alpha_composite(white_background.convert("RGBA"), img).convert("RGB")
#     return rgb_image

# def get_ray(image_dir, pose, camera_angle_x, H, W):
#   trans_mat = np.array(pose)

#   f_x = (W / 2) / np.tan(camera_angle_x / 2)
#   f_y = f_x
#   c_x = W / 2
#   c_y = H / 2
#   # Intrinsic camera matrix
#   K = np.array([
#       [f_x, 0, c_x],
#       [0, f_y, c_y],
#       [0, 0, 1]
#   ])

#   K_inv = np.linalg.inv(K)

#   ray_directions_camera = generate_ray_directions(K_inv, W, H)
#   rotation_matrix = trans_mat[:3, :3]
#   rays_d = ray_directions_camera @ rotation_matrix.T
#   rays_o = np.tile(trans_mat[:3, 3], (W * H, 1))
#   ground_truth_px_values = load_ground_truth_image(image_dir)
#   reshaped_px_values = np.reshape(ground_truth_px_values, (H*W, 3)) / 255




#   return rays_o, rays_d, reshaped_px_values

In [None]:
# def load_ground_truth_image(image_path):
#     # Load an image file as ground truth
#     with Image.open(image_path) as img:
#         white_background = Image.new("RGB", img.size, (255, 255, 255))
#         rgb_image = Image.alpha_composite(white_background.convert("RGBA"), img).convert("RGB")
#     return rgb_image

# def get_ray(image_dir, pose, camera_angle_x, H, W):
#   trans_mat = np.array(pose)
#   f_x = (W / 2) / np.tan(camera_angle_x / 2)

#   x, y = np.meshgrid(
#             np.arange(W, dtype=np.float32),  # X-Axis (columns)
#             np.arange(H, dtype=np.float32),  # Y-Axis (rows)
#             indexing='xy')
#   homogeneous_directions = np.stack(
#             [(x - W * 0.5) / f_x,
#             -(y - H * 0.5) / f_x,
#             -np.ones_like(x)],
#             axis=-1)

#   c2w = trans_mat

#   rays_d = homogeneous_directions @ c2w[:3, :3].T
#   rays_d = rays_d / np.linalg.norm(rays_d, axis=-1, keepdims=True)
#   rays_o = np.tile(c2w[:3, 3], (H, W, 1))


#   rays_d = rays_d.reshape(-1, 3)
#   rays_o = rays_o.reshape(-1, 3)



#   ground_truth_px_values = load_ground_truth_image(image_dir)
#   reshaped_px_values = np.reshape(ground_truth_px_values, (H*W, 3)) / 255




#   return rays_o, rays_d, reshaped_px_values

# Acknowledgements