In [None]:
!pip install torch
!pip install numpy
!pip install matplotlib
!pip install tqdm



In [None]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader


class NerfModel(nn.Module):
    """NeRF MLP mapping a 3D position and a viewing direction to an RGB colour and a volume density.

    Args:
        embedding_dim_pos: number of frequency octaves L used to encode positions; the encoded
            position has 3 + 6 * L features (raw xyz plus sin/cos per octave and coordinate).
        embedding_dim_direction: number of frequency octaves used to encode viewing directions
            (3 + 6 * L features).
        hidden_dim: width of the hidden layers; the colour head uses hidden_dim // 2 units.

    Inputs to ``forward`` may carry any leading batch dimensions as long as the last one is 3.
    """

    def __init__(self, embedding_dim_pos=10, embedding_dim_direction=4, hidden_dim=128):
        super().__init__()
        pos_features = embedding_dim_pos * 6 + 3
        dir_features = embedding_dim_direction * 6 + 3

        # First four fully-connected ReLU layers on the encoded position
        # (the encoding already contains the raw xyz coordinates).
        self.block1 = nn.Sequential(nn.Linear(pos_features, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), )
        # Skip connection: block1's output is concatenated with the encoded position again.
        # Together with block1 this gives eight linear layers; the last one has no activation
        # and emits hidden_dim features plus one raw density value.
        self.block2 = nn.Sequential(nn.Linear(hidden_dim + pos_features, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                    nn.Linear(hidden_dim, hidden_dim + 1), )
        # Colour head: the feature vector concatenated with the encoded viewing direction goes
        # through one ReLU layer of width hidden_dim // 2 (128 when hidden_dim=256) ...
        self.block3 = nn.Sequential(nn.Linear(hidden_dim + dir_features, hidden_dim // 2), nn.ReLU(), )
        # ... and a sigmoid output layer, so colours lie in [0, 1].
        self.block4 = nn.Sequential(nn.Linear(hidden_dim // 2, 3), nn.Sigmoid(), )

        # Octave counts are needed again in forward().
        self.embedding_dim_pos = embedding_dim_pos
        self.embedding_dim_direction = embedding_dim_direction
        # Keeps the predicted density non-negative.
        self.relu = nn.ReLU()

    @staticmethod
    def positional_encoding(x, L):
        """Frequency-encode ``x`` of shape (..., D) into shape (..., D * (1 + 2 * L)).

        The output is ``x`` followed by ``sin(2**j * x)`` and ``cos(2**j * x)`` for
        j = 0 .. L-1, concatenated along the last axis.
        """
        out = [x]
        for j in range(L):
            out.append(torch.sin(2 ** j * x))
            out.append(torch.cos(2 ** j * x))
        # Concatenate on the feature axis so any number of leading batch dims is supported.
        return torch.cat(out, dim=-1)

    def forward(self, o, d):
        """Return ``(c, sigma)`` for sample positions ``o`` and viewing directions ``d``.

        Args:
            o: sample positions, shape (..., 3).
            d: viewing directions, shape (..., 3), with the same leading shape as ``o``.

        Returns:
            c: RGB colour in [0, 1] (sigmoid output), shape (..., 3).
            sigma: non-negative volume density (ReLU output), shape (...).
        """
        emb_x = self.positional_encoding(o, self.embedding_dim_pos)
        emb_d = self.positional_encoding(d, self.embedding_dim_direction)
        h = self.block1(emb_x)
        tmp = self.block2(torch.cat((h, emb_x), dim=-1))
        # The last channel is density; the rest is the feature vector for the colour head.
        h, sigma = tmp[..., :-1], self.relu(tmp[..., -1])
        h = self.block3(torch.cat((h, emb_d), dim=-1))
        c = self.block4(h)
        return c, sigma

@torch.no_grad()
def test(hn, hf, dataset, chunk_size=10, img_index=0, nb_bins=192, H=400, W=400, nerf_model=None,
         save_dir='/content/drive/MyDrive/Colab Notebooks/Nerf_Imp/results'):
    """Render one image of ``dataset`` and save it as ``<save_dir>/img_<img_index>.png``.

    Args:
        hn, hf: near/far bounds forwarded to ``render_rays``.
        dataset: tensor of rays with H * W consecutive rows per image; columns 0:3 are ray
            origins and 3:6 ray directions (any further columns are ignored here).
        chunk_size: number of image rows rendered per forward pass (bounds peak memory).
        img_index: which image of ``dataset`` to render.
        nb_bins: samples per ray.
        H, W: image height and width in pixels.
        nerf_model: model to render with. Defaults to the module-level ``model`` created in
            ``__main__``, so existing calls keep working.
        save_dir: output directory; created if it does not exist.

    Returns:
        The rendered image as an (H, W, 3) numpy array.
    """
    import os

    if nerf_model is None:
        nerf_model = model  # module-level global created in __main__
    # Rays must live on the same device as the model's weights.
    render_device = next(nerf_model.parameters()).device

    start, stop = img_index * H * W, (img_index + 1) * H * W
    ray_origins = dataset[start:stop, :3]
    ray_directions = dataset[start:stop, 3:6]

    rays_per_chunk = W * chunk_size
    data = []
    for i in range(int(np.ceil(H / chunk_size))):
        chunk = slice(i * rays_per_chunk, (i + 1) * rays_per_chunk)
        ray_origins_ = ray_origins[chunk].to(render_device)
        ray_directions_ = ray_directions[chunk].to(render_device)
        data.append(render_rays(nerf_model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins))
    img = torch.cat(data).cpu().numpy().reshape(H, W, 3)
    # With sigmoid colours the composited values are in [0, 1] up to float rounding;
    # clip so imshow does not warn about out-of-range RGB data.
    img = np.clip(img, 0.0, 1.0)

    os.makedirs(save_dir, exist_ok=True)
    fig = plt.figure()
    plt.imshow(img)
    plt.savefig(os.path.join(save_dir, f'img_{img_index}.png'), bbox_inches='tight')
    plt.close(fig)
    return img

# Compute the accumulated transmittance: the fraction of light that survives every sample in
# front of a given sample along the ray (the T_i term of the volume-rendering sum).
def compute_accumulated_transmittance(alphas):
    """Return the transmittance reaching each sample along a ray.

    Despite its name, ``alphas`` holds per-sample *transmittances* ``1 - alpha`` (this file's
    ``render_rays`` calls it with ``1 - alpha``).

    Args:
        alphas: tensor of shape (..., nb_bins); the sample axis is the last dimension.

    Returns:
        Tensor with the same shape, dtype and device, where ``T[..., 0] = 1`` and
        ``T[..., i] = prod(alphas[..., :i])`` (an exclusive cumulative product).
    """
    accumulated_transmittance = torch.cumprod(alphas, dim=-1)
    # Shift right by one so each sample only sees the samples in front of it; nothing occludes
    # the first sample, so its transmittance is 1.
    first = torch.ones((*alphas.shape[:-1], 1), dtype=alphas.dtype, device=alphas.device)
    return torch.cat((first, accumulated_transmittance[..., :-1]), dim=-1)


# Render a batch of rays into pixel colours by sampling the model along each ray between the
# near (hn) and far (hf) bounds and alpha-compositing the samples over a white background.
def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192, perturb=True):
    """Volume-render rays and return one RGB colour per ray.

    Args:
        nerf_model: callable mapping (points (M, 3), directions (M, 3)) to
            (colours (M, 3), densities (M,)).
        ray_origins: ray origins, shape (N, 3).
        ray_directions: ray directions, shape (N, 3). Sample spacing ``delta`` is measured in
            the ray parameter t and is not scaled by the direction norm, so this presumably
            assumes unit-length directions. TODO confirm against the dataset.
        hn, hf: near and far bounds of the ray parameter t.
        nb_bins: number of samples per ray.
        perturb: if True (the default and the original behaviour), jitter each sample
            uniformly within its stratum (stratified sampling). If False, use the evenly
            spaced t values, which gives deterministic renders (useful for evaluation).

    Returns:
        Tensor of shape (N, 3): composited colours over a white background.
    """
    device = ray_origins.device

    t = torch.linspace(hn, hf, nb_bins, device=device).expand(ray_origins.shape[0], nb_bins)
    if perturb:
        # Stratified sampling: draw one uniform sample inside each interval bounded by the
        # midpoints between neighbouring t values.
        mid = (t[:, :-1] + t[:, 1:]) / 2.
        lower = torch.cat((t[:, :1], mid), -1)
        upper = torch.cat((mid, t[:, -1:]), -1)
        u = torch.rand(t.shape, device=device)
        t = lower + (upper - lower) * u

    # Distance between consecutive samples. The last interval is treated as effectively
    # infinite, so the final sample becomes fully opaque whenever its density is positive.
    delta = torch.cat((t[:, 1:] - t[:, :-1], torch.full_like(t[:, :1], 1e10)), -1)

    # 3D sample positions along every ray: (N, nb_bins, 3).
    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)
    # Every sample on a ray shares that ray's viewing direction: (N, nb_bins, 3).
    sample_directions = ray_directions.unsqueeze(1).expand(-1, nb_bins, -1)

    colors, sigma = nerf_model(x.reshape(-1, 3), sample_directions.reshape(-1, 3))
    colors = colors.reshape(x.shape)      # (N, nb_bins, 3)
    sigma = sigma.reshape(x.shape[:-1])   # (N, nb_bins)

    # Opacity of each segment from its density and length.
    alpha = 1 - torch.exp(-sigma * delta)
    # weight_i = T_i * alpha_i, where T_i is the light surviving all earlier samples.
    weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)

    c = (weights * colors).sum(dim=1)
    # Sum of weights per ray = 1 - (transmittance remaining after the last sample).
    weight_sum = weights.sum(dim=(1, 2))

    # The light that passes through all samples (1 - weight_sum) is composited with a white
    # background (colour 1.0), so empty space renders as white rather than black.
    return c + 1 - weight_sum.unsqueeze(-1)


def train(nerf_model, optimizer, scheduler, data_loader, device='cpu', hn=0, hf=1, nb_epochs=int(1e5),
          nb_bins=192, H=400, W=400, test_dataset=None, nb_test_images=200):
    """Optimise ``nerf_model`` on batches of rays and render test images after every epoch.

    Each batch row holds a ray origin (columns 0:3), a ray direction (3:6) and the
    ground-truth pixel colour (6:).

    Args:
        nerf_model: model passed to ``render_rays``.
        optimizer: optimiser over ``nerf_model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        data_loader: iterable of ray batches.
        device: device the batches are moved to.
        hn, hf: near/far bounds of the ray parameter.
        nb_epochs: number of passes over ``data_loader``.
        nb_bins: samples per ray.
        H, W: test-image height and width.
        test_dataset: rays rendered after each epoch. Defaults to the module-level
            ``testing_dataset`` created in ``__main__``, so existing calls keep working.
        nb_test_images: number of test images rendered after each epoch (default 200, as
            before). Each is a full H x W render with nb_bins samples per ray, which dominates
            epoch time; pass a smaller value, or 0 to skip evaluation.

    Returns:
        List of per-batch training losses.
    """
    if nb_test_images > 0 and test_dataset is None:
        test_dataset = testing_dataset  # module-level global created in __main__

    training_loss = []
    for _ in tqdm(range(nb_epochs)):
        for batch in data_loader:
            # One transfer per batch instead of three column-slice copies.
            batch = batch.to(device)
            ray_origins = batch[:, :3]
            ray_directions = batch[:, 3:6]
            ground_truth_px_values = batch[:, 6:]

            regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf,
                                                nb_bins=nb_bins)
            # Sum (not mean) of squared errors over the batch.
            loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss.append(loss.item())
        scheduler.step()

        for img_index in range(nb_test_images):
            test(hn, hf, test_dataset, img_index=img_index, nb_bins=nb_bins, H=H, W=W)
    return training_loss


if __name__ == '__main__':
    # Prefer the GPU but fall back to CPU instead of crashing on CPU-only runtimes.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    data_dir = '/content/drive/MyDrive/Colab Notebooks/Nerf_Imp'
    # NOTE(security): allow_pickle=True unpickles the file, which can execute arbitrary code;
    # only load data files from a trusted source.
    training_dataset = torch.from_numpy(np.load(f'{data_dir}/training_data.pkl', allow_pickle=True))
    testing_dataset = torch.from_numpy(np.load(f'{data_dir}/testing_data.pkl', allow_pickle=True))

    model = NerfModel(hidden_dim=256).to(device)
    model_optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    # train() steps the scheduler once per epoch, so the LR halves after epochs 2, 4 and 8.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(model_optimizer, milestones=[2, 4, 8], gamma=0.5)
    # Pinned host memory only speeds up host-to-GPU copies.
    data_loader = DataLoader(training_dataset, batch_size=1024, shuffle=True, num_workers=4,
                             pin_memory=(device == 'cuda'))
    training_loss = train(model, model_optimizer, scheduler, data_loader, nb_epochs=16, device=device, hn=2,
                          hf=6, nb_bins=192, H=400, W=400)

 44%|████▍     | 7/16 [4:33:39<6:09:00, 2460.06s/it]