# Nerfstudio To D-NeRF dataset

Set Up:

1. Installl Nerfstudio and activate environment
2. Run this notebook with the nerfstudio conda environment

Instructions:

1. Update configuration relative to desired parameters
2. Run the method

#### Notes on D-NeRF dataset:

In the nerfstudio DNeRF parser, `camera_angle_x` parameter (found in the transforms files) has a $1/tanh(0.5*x)$ relationship with the focal length. While I did mess around with this in [Desmos](https://www.desmos.com/calculator/xw0lodoghb) and initially selected a high $x=4.$ value. It didn't play well with kplanes (tested with scene contraction and varying near and far plane positions). In the end $x=0.6$ worked best despite how counter intuitive this is... The parameter is important for good rendering but is not directly recoverable from the nerfstudio dataset (it might be I just haven't looked into it).

Otherwise the `rotation` parameter found for each frame doesn't seem to have any impact on performance (at least in the tests I ran with K-Planes, **this maybe different for other models!!!**).

---
# Configuration
---

Information:
1. `nerfstudio_fp` is the path **to the folder** containing `transforms.json`
2. `output_fp` is the path to the folder you wish to write `transforms_train.json`, `transforms_test.json` and `transforms_val.json`
3. `downscale_images_fp` is the folder name inside of you nerfstudio folder containing the downscaled images if you would rather save these than the original images. As dnerf format doesn't consider downscaled images, this will allow you to use them instead.
4. `method` declares the way we recover the time values:
    - `'exhaustive'`    : match image to frame (**super slow**)
    - `'linear'`    : assign time given image name-index (**fast**) (e.g. image w/name `frame_{i}.png` is going to be at i/n time where n is the number of images and `i` is the index; this is the image-naming format used by nerfstudio-colamp process)

In [1]:
nerfstudio_fp ='path/to/folder/containing/tranform/file'
video_fp = 'path/to/video.mp4'
output_fp = 'path/to/write/transform/to'

downscale_images_fp = 'path/to/folder/containing/downscalled/images'

method = 'exhaustive'

# handler(nerfstudio_fp, output_fp, video_fp, downscale_images_fp, method)

In [9]:
nerfstudio_fp ='data/boat_colmap/'
video_fp = 'data/boat/boat.mp4'
output_fp = 'data/boat_exhaustive/'

downscale_images_fp = 'images/'

method = 'exhaustive'

handler(nerfstudio_fp, output_fp, video_fp, downscale_images_fp, method, rotation=0.0, camera_angle_x=0.9)

Total number of frames to process 534 
 Total number of images to process 539
Running


0it [00:00, ?it/s]

Image 4 has no match: consider lowering SSIM threshold
Image 5 has no match: consider lowering SSIM threshold
Image 6 has no match: consider lowering SSIM threshold
Image 7 has no match: consider lowering SSIM threshold
Image 8 has no match: consider lowering SSIM threshold
Image 9 has no match: consider lowering SSIM threshold
Image 10 has no match: consider lowering SSIM threshold
Image 11 has no match: consider lowering SSIM threshold
Image 12 has no match: consider lowering SSIM threshold
Image 13 has no match: consider lowering SSIM threshold
Image 14 has no match: consider lowering SSIM threshold
Image 15 has no match: consider lowering SSIM threshold
Image 16 has no match: consider lowering SSIM threshold
Image 17 has no match: consider lowering SSIM threshold
Image 18 has no match: consider lowering SSIM threshold
Image 19 has no match: consider lowering SSIM threshold
Image 20 has no match: consider lowering SSIM threshold
Image 21 has no match: consider lowering SSIM threshol

TypeError: list indices must be integers or slices, not dict

---
# View the code
---

Information:
1. Import dependencies. *Make sure nerfstudio has been downloaded*
2. View the functions
3. Change the functions (optional)
4. Run the functions

In [2]:
# Import
import os
import json
from pathlib import Path
import random
import shutil

import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
from tqdm.notebook import tqdm

from utils_ import *

### Exhaustive SSIM Time Search

**Args:**
- d_fp, v_fp, img_fp: Path, previously discussed
- transfors_fp: Path, path to transforms file

**Notes:**
1. Exhaustive search matches each image (e.g. `frame_0000.png`, `frame_0001.png`, ... ) to each frame in a video. 
2. Each image is a frame with png compression so direct image to frame comparison isn't possible
3. Instead we compare w/ SSIM.
4. This means:

    a. Overlapping frames (such as monocular stationary camera with negligible dynamic motion) will have the same SSIM score and so we will get several frames which match
        
    b. We select the earliest occuring frame match as the time of the png image
    
    c. We accept that this may not always be the case so we add a threshold, whereby we search for the earliest match where SSIM > 0.95 , when this is not the case max(SSIM) > 0.9 is selected.
    
    d. Theoretically, this shouldn't be an issue for NeRF as SSIM threshold is high so should be negligible during NeRF evaluation
    

In [3]:
def exhaustive(d_fp, o_fp, v_fp, img_fp, transforms_fp, shuffle:bool=True):
    assert not os.path.exists(o_fp), 'Folder already exists, delete folder to run'

    os.makedirs(o_fp)
    with open(transforms_fp) as fp:
        contents = fp.read()
    transforms = json.loads(contents)
    img_frames = transforms['frames'] # Directly access frame data

    # Initialise opencv video object
    video = cv2.VideoCapture(str(v_fp))
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    print(f'Total number of frames to process {total_frames} \n Total number of images to process {len(img_frames)}')


    video_frame_counter = 0
    matches = []

    # Sort img_frames by filepath name
    filenum = []
    for img in img_frames:
        num = img['file_path'].split('/')[-1].split('.')[-2].split('_')[-1]
        filenum.append(int(num))
    img_frames = [img for _, img in sorted(zip(filenum, img_frames))]
    
    print('Running')
    iterator = tqdm(enumerate(img_frames))
    # loop through each image in our colmap dataset
    for idx, img in iterator:
        fp = d_fp / img['file_path']
        image = cv2.imread(str(fp), cv2.IMREAD_GRAYSCALE) # load image greyscale

        SSIM = {
            'max': 0.,
            'idxs':[]
        }
        for idx_video in range(video_frame_counter, total_frames):
            # Fetch frame from video
            video.set(cv2.CAP_PROP_FRAME_COUNT, idx_video)
            ret, frame = video.read()
            if ret: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) # load frame greyscale
            else: break

            # Get SSIM
            ssim_res = ssim(image, frame)

            # Process SSIM
            if ssim_res > 0.95: # if we meet ideal match
                SSIM['max'] = 1.
                SSIM['idxs'] = [idx_video] 
                break
            elif ssim_res > SSIM['max']: # if we have a new max
                SSIM['max'] = ssim_res
                SSIM['idxs'] = [idx_video]
            elif ssim_res == SSIM['max']: # if we have the same max
                SSIM['idxs'].append(idx_video)
        

        if SSIM['max'] < 0.9:
            print(f'Image {idx} has no match: consider lowering SSIM threshold')
        else:
            idx_video = min(SSIM['idxs'])
            matches.append({
                "frame" : float(idx_video/total_frames),
                "image" : idx
                })

    # shuffle data
    if shuffle == True:
        random.shuffle(matches)

    # train split
    train_split = 0.9
    train_idx = int(train_split * len(matches))
    train_data = matches[:train_idx]
    # test split (on remaining data)
    test_split = 0.9
    test_index = train_idx + int(test_split * (len(matches) - len(train_data)))
    test_data =  matches[train_idx : test_index]
    # val split
    val_data = matches[test_index:]


    # Construct transform files
    local_properties = {
        "roation": 0.0,
    }
    data = [train_data, test_data, val_data]
    for idx, d in enumerate(data):
        file_ = {
            "camera_angle_x": 0.6, # Seems to work best on the toy datasets I used
            "frames":[]
        }
        # file_path
        if idx == 0: file_path = Path(o_fp) / 'transforms_train.json'
        elif idx == 1: file_path = Path(o_fp) / 'transforms_test.json'
        elif idx == 2: file_path = Path(o_fp) / 'transforms_val.json'
        
        for match_idx in d:
            img_data = img_frames[match_idx]
            time = float(img_data['idx'] / total_frames)
            fname = img_data['file_path'].split('/')[-1]
            
            file_["frames"].append({
            "file_path":f'./train/{fname}',
            "rotation": local_properties['rotation'],
            "time":time,
            "transform_matrix":img_data['transform_matrix']
        })



### Blind Linear Time Search

**Args:**
- Same as before

**Notes:**
- Order collected images by frame number and linearly assign time value between 0 and 1.
- This will make the motion linear with colmap extraction so it may not be ideal.

In [4]:
def linear(d_fp, o_fp, v_fp, img_fp, transforms_fp, shuffle:bool=True, rotation:float=0.0, camera_angle_x:float=0.0):
    assert not os.path.exists(o_fp), 'Folder already exists, delete folder to run'

    os.makedirs(o_fp)
    os.makedirs(o_fp+'/train')
    os.makedirs(o_fp+'/test')
    os.makedirs(o_fp+'/val')
    with open(transforms_fp) as fp:
        contents = fp.read()
    transforms = json.loads(contents)
    img_frames = transforms['frames'] # Directly access frame data
    print(f'Total number of images to process {len(img_frames)}')

    local_properties = {
        "rotation":rotation,
    }

    frames = []
    for idx, img in enumerate(img_frames):
        fname = img['file_path'].split('/')[-1].split('_')[-1].split('.')[0]

        frames.append({
            "file_path":f'./train/frame_{fname}',
            "rotation": local_properties['rotation'],
            "time":int(fname)/len(img_frames),
            "transform_matrix":img['transform_matrix']
        })
    
    # shuffle data
    if shuffle == True:
        random.shuffle(frames)

    # train split
    train_split = 0.9
    train_idx = int(train_split * len(frames))
    train_data = frames[:train_idx]
    # test split (on remaining data)
    test_split = 0.9
    test_index = train_idx + int(test_split * (len(frames) - len(train_data)))
    test_data =  frames[train_idx : test_index]
    # val split
    val_data = frames[test_index:]

    print(f' Training dataset: {len(train_data)} | Testing dataset: {len(test_data)}')
    data = [train_data, test_data, val_data]
    for idx, d in enumerate(data):
        file_ = {
            "camera_angle_x": camera_angle_x,
            "frames":[]
        }
        # file_path
        if idx == 0: file_path = Path(o_fp) / 'transforms_train.json'
        elif idx == 1: file_path = Path(o_fp) / 'transforms_test.json'
        elif idx == 2: file_path = Path(o_fp) / 'transforms_val.json'
        
        for frame in d:
            file_['frames'].append(frame)

            if idx == 0:
                destination = Path(o_fp) / 'train' / (frame['file_path'].split('/')[-1] + '.png')
                source = img_fp / (frame['file_path'].split('/')[-1] + '.png')
                shutil.copyfile(source, destination)
            elif idx == 1:
                destination = Path(o_fp) / 'test' / (frame['file_path'].split('/')[-1] + '.png')
                source = img_fp / (frame['file_path'].split('/')[-1] + '.png')
                shutil.copyfile(source, destination)
            elif idx == 2:
                destination = Path(o_fp) / 'val' / (frame['file_path'].split('/')[-1] + '.png')
                source = img_fp / (frame['file_path'].split('/')[-1] + '.png')
                shutil.copyfile(source, destination)
        
        with open(file_path, 'w') as fp:
            json.dump(file_, fp)   

    

### Handler for nerfstudio2dnerf

**Args:**
- `d_fp`, Path, path to `transforms.json` **folder**
- `o_fp`, Path, path to output folder
- `v_fp`, Path, path to video
- `img_fp`, Path, path to image folder

In [7]:
def handler(d_fp, o_fp, v_fp, img_fp, meth, rotation:float=0.0, camera_angle_x:float=0.0):
    d_fp = Path(d_fp)
    v_fp = Path(v_fp)
    img_fp = d_fp / img_fp

    transforms_fp = d_fp/'transforms.json'

    # Sanity Checks
    pathchecks([d_fp, v_fp, img_fp])
    folderchecks([d_fp, img_fp])
    
    # meth = 'linear'
    # Handle exhaustive method
    if meth == 'exhaustive':
        exhaustive(d_fp, o_fp, v_fp, img_fp,transforms_fp)    
    elif meth == 'linear':
        linear(d_fp, o_fp, v_fp, img_fp,transforms_fp, rotation=rotation, camera_angle_x=camera_angle_x)    