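# Jupyter notebook for image preprocessing for COLMAP
This notebook prototypes extracting undistorted frames from an Aria VRS recording for use with COLMAP. It is a testbed; the code will be modularized and cleaned up afterwards.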

In [None]:
from projectaria_tools.projects.aea import AriaEverydayActivitiesDataPathsProvider, AriaEverydayActivitiesDataProvider
from projectaria_tools.core import data_provider, calibration
from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions
from projectaria_tools.core.stream_id import RecordableTypeId, StreamId

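Configure the input recording, frame sampling step, and output directory. The VRS data provider itself is created later, inside `process_vrs`.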

In [None]:
# All paths are relative to the notebook's working directory.
sample_dir = "./../data/loc1_script2_seq1_rec1_10s_sample"
data_path = f"{sample_dir}/recording.vrs"  # input Aria VRS recording
frame_step = 1  # frame sampling step for process_vrs / save_frames (1 = keep every frame)
output_dir = f"{sample_dir}/images"  # destination for the extracted JPEG frames

### Image undistortion
Given the RGB video stream of a VRS file, `undistort` rectifies each frame to a linear (pinhole) camera model. It then rotates each frame and returns all frames as a list of images.

In [None]:
from PIL import Image
import numpy as np

def undistort(provider, sensor_name="camera-rgb", crop_px=120):
    """Rectify every frame of a camera stream and return them as PIL images.

    Each frame is resampled from the device's calibration for ``sensor_name``
    into a linear (pinhole) camera with the same first focal length. It is then
    rotated 90 degrees clockwise (``np.rot90(..., k=3)``).

    Args:
        provider: a projectaria_tools VRS data provider, e.g. the result of
            ``data_provider.create_vrs_data_provider``.
        sensor_name: label of the camera stream to process.
        crop_px: number of pixels subtracted from both the width and the height
            of the linear target camera. Presumably this trims the invalid
            border left by rectification (TODO confirm).

    Returns:
        list of PIL images, one per frame in stream order.

    Note: all frames are held in memory at once. For long recordings,
    ``process_vrs`` writes each frame to disk as it goes instead.
    """
    stream_id = provider.get_stream_id_from_label(sensor_name)
    src_calib = provider.get_device_calibration().get_camera_calib(sensor_name)

    image_size = src_calib.get_image_size()
    width, height = int(image_size[0]), int(image_size[1])
    focal = src_calib.get_focal_lengths()[0]
    dst_calib = calibration.get_linear_camera_calibration(
        width - crop_px, height - crop_px, focal, sensor_name
    )

    images = []
    for index in range(provider.get_num_data(stream_id)):
        # Element [0] of the returned pair carries the pixel data.
        raw_image = provider.get_image_data_by_index(stream_id, index)[0].to_numpy_array()
        rectified = calibration.distort_by_calibration(raw_image, dst_calib, src_calib)
        # k=3 rotates clockwise; presumably this makes Aria's sideways-mounted
        # RGB sensor upright (TODO confirm).
        images.append(Image.fromarray(np.rot90(rectified, k=3)))
    return images
    
# output = undistort(provider)

## Save images to a directory

In [None]:
from tqdm import tqdm
import os
def save_frames(images, output_dir, frame_step=1):
    """Save every ``frame_step``-th image to ``output_dir`` as ``frame_XXXX.jpg``.

    Files are numbered by the image's index in ``images``, not by how many
    files have been written. Frame numbers therefore stay aligned with the
    source sequence when frames are skipped.

    Args:
        images: indexable sequence of objects with a ``save(path)`` method,
            e.g. the PIL images returned by ``undistort``.
        output_dir: destination directory; created if missing.
        frame_step: keep indices 0, frame_step, 2*frame_step, ...; must be >= 1.

    Raises:
        ValueError: if ``frame_step`` is less than 1.
    """
    if frame_step < 1:
        raise ValueError(f"frame_step must be >= 1, got {frame_step}")
    os.makedirs(output_dir, exist_ok=True)
    # Iterate only over the kept indices instead of testing every index.
    for index in tqdm(range(0, len(images), frame_step), desc="saving frames"):
        images[index].save(os.path.join(output_dir, f"frame_{index:04d}.jpg"))
# save_frames(output, "./output", 1)

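## Combine undistortion and saving into one function
This is more efficient than the two-step approach above. Each kept frame is rectified and saved in turn, so frames are never accumulated in memory, and frames skipped by `frame_step` are never decoded.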

In [None]:
def process_vrs(vrs_path, output_dir, frame_step=1, sensor_name="camera-rgb", crop_px=120):
    """Rectify frames of a VRS recording and write them straight to disk.

    This is the streaming counterpart of ``undistort`` followed by
    ``save_frames``. Each kept frame is decoded, rectified to a linear
    (pinhole) camera, rotated 90 degrees clockwise and saved as
    ``frame_XXXX.jpg`` before the next one is read.

    Args:
        vrs_path: path to the ``.vrs`` recording.
        output_dir: destination directory; created if missing.
        frame_step: keep frames 0, frame_step, 2*frame_step, ...; must be >= 1.
        sensor_name: label of the camera stream to process.
        crop_px: number of pixels subtracted from both the width and the height
            of the linear target camera.

    Raises:
        ValueError: if ``frame_step`` is less than 1.
        RuntimeError: if the data provider could not be created for ``vrs_path``.
    """
    if frame_step < 1:
        raise ValueError(f"frame_step must be >= 1, got {frame_step}")

    provider = data_provider.create_vrs_data_provider(vrs_path)
    # NOTE(review): the Aria tutorials check this result against None when a
    # file cannot be opened. Fail loudly here instead of on a later attribute access.
    if provider is None:
        raise RuntimeError(f"Could not open VRS recording: {vrs_path}")

    stream_id = provider.get_stream_id_from_label(sensor_name)
    src_calib = provider.get_device_calibration().get_camera_calib(sensor_name)

    image_size = src_calib.get_image_size()
    width, height = int(image_size[0]), int(image_size[1])
    focal = src_calib.get_focal_lengths()[0]
    dst_calib = calibration.get_linear_camera_calibration(
        width - crop_px, height - crop_px, focal, sensor_name
    )

    os.makedirs(output_dir, exist_ok=True)

    num_frames = provider.get_num_data(stream_id)
    # Visit only the kept frames; skipped frames are never decoded.
    for index in tqdm(range(0, num_frames, frame_step), desc="processing frames"):
        raw_image = provider.get_image_data_by_index(stream_id, index)[0].to_numpy_array()
        rectified = calibration.distort_by_calibration(raw_image, dst_calib, src_calib)
        upright = Image.fromarray(np.rot90(rectified, k=3))
        upright.save(os.path.join(output_dir, f"frame_{index:04d}.jpg"))
            
# Use the frame_step from the configuration cell instead of a hard-coded 1.
process_vrs(data_path, output_dir, frame_step=frame_step)