### FilterNet
A neural network that removes Gaussian noise from camera images, where the noise is introduced by aggressive accelerations of the vehicle carrying the camera.

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from scipy.spatial.transform import Rotation as R
from skimage import io, transform
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils

In [None]:
# Create the dataset that feeds the DataLoader.

# TODO: Load the poses and twists once per video in the __init__() function.
# This uses more RAM but greatly reduces I/O overhead: __getitem__() currently
# re-reads both CSV files for every single frame it returns.

class VideoDataset(Dataset):
    """
    Map-style dataset over every frame of ``num_vids`` recorded videos.

    Each item is ``(frame, pose, twist)`` for a single frame. Items are read
    lazily from disk in ``__getitem__`` so the whole dataset never has to be
    held in RAM. The expected on-disk layout is::

        data:
        - test1:
          - video:
            - frame0001.png
            - ...
          - poses:
            - poses.csv
            - twists.csv
        - test2:
            - ...
        - ...
    """

    def __init__(self, path_to_data, num_vids=7):
        """
        Build a flat index over all frames of all videos.

        Entry ``i`` of ``self.idx_map`` (``0 <= i < N``, where N is the total
        number of frames in all the videos) is a tuple ``(j, k)``. Here ``j``
        is the 1-based video number (folder ``test{j}``) and ``k`` is the
        0-based row of that video's ``poses.csv``.

        Args:
            path_to_data: Root directory containing the ``test{j}`` folders.
            num_vids: Number of videos; folders ``test1`` .. ``test{num_vids}``
                are indexed.
        """
        self.num_vids = num_vids
        self.path_to_data = path_to_data
        self.idx_map = []
        for i in range(1, self.num_vids + 1):
            # Only the row count is needed here: one poses.csv row per frame.
            num_frames = len(pd.read_csv(self._csv_path(i, 'poses.csv')))
            self.idx_map.extend((i, k) for k in range(num_frames))

        self.total_num_frames = len(self.idx_map)

    def __len__(self):
        """Return the total number of frames counted in ``__init__``."""
        return self.total_num_frames

    def __getitem__(self, idx):
        """
        Load the image, pose and twist of one frame directly from disk.

        Reading per item avoids loading the entire dataset, which would
        overwhelm the RAM.

        Returns:
            ``(frame, pose, twist)``: the output of ``to_grayscale_bitmask``,
            the length-7 pose vector as a tensor, and the twist CSV row as a
            tensor.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        test_idx, frame_idx = self.idx_map[idx]
        # NOTE(review): frame_idx is 0-based, while the documented layout
        # starts at frame0001.png -- confirm whether an offset of 1 is needed.
        img_path = os.path.join(self._video_dir(test_idx), 'video',
                                f'frame{self.to_lexi_idx(frame_idx)}.png')

        frame = self.to_grayscale_bitmask(io.imread(img_path))
        # Assumes each poses.csv row holds the 16 entries of a 4x4 pose
        # matrix in row-major order -- TODO confirm against the data.
        pose = pd.read_csv(self._csv_path(test_idx, 'poses.csv')).to_numpy()[frame_idx]
        pose = self.pose_matrix_to_vector(pose)
        twist = pd.read_csv(self._csv_path(test_idx, 'twists.csv')).to_numpy()[frame_idx]

        return frame, torch.from_numpy(pose), torch.from_numpy(twist)

    def _video_dir(self, test_idx):
        """Return the folder of video ``test_idx`` under ``path_to_data``."""
        # Use relative components only: a component starting with '/' makes
        # os.path.join discard path_to_data entirely.
        return os.path.join(self.path_to_data, f'test{test_idx}')

    def _csv_path(self, test_idx, filename):
        """Return the path of ``filename`` inside video ``test_idx``'s poses folder."""
        return os.path.join(self._video_dir(test_idx), 'poses', filename)

    def to_lexi_idx(self, idx, num_digits=5):
        """
        Zero-pad a frame index so lexicographic order matches numeric order.

        e.g. 1 -> '00001', 12 -> '00012'

        Args:
            idx: Non-negative integer frame index.
            num_digits: Width to pad to. Indices with more digits are returned
                unpadded and would then no longer sort correctly as strings.

        Returns:
            The zero-padded index as a string.
        """
        # NOTE(review): the documented layout shows 4-digit file names
        # (frame0001.png), but the examples above use 5 digits -- confirm
        # the real width and pass num_digits accordingly.
        return f'{idx:0{num_digits}d}'

    def pose_matrix_to_vector(self, pose):
        """
        Convert a 4x4 homogeneous pose matrix into a length-7 vector.

        Args:
            pose: 4x4 pose matrix, or its 16 entries flattened in row-major
                order (the form a single CSV row arrives in).

        Returns:
            np.ndarray of shape (7,): ``[x, y, z, qx, qy, qz, qw]``, i.e. the
            translation followed by the unit quaternion in SciPy's
            scalar-last order.
        """
        pose = np.asarray(pose, dtype=float).reshape(4, 4)
        position = pose[:3, 3]
        quat = R.from_matrix(pose[:3, :3]).as_quat()
        # SciPy already returns a unit quaternion; normalise defensively anyway.
        norm_quat = quat / np.linalg.norm(quat)

        return np.concatenate((position, norm_quat))

    def to_grayscale_bitmask(self, image):
        """
        Convert a PNG image (as read by ``skimage.io.imread``) to a grayscale bitmask.

        Raises:
            NotImplementedError: always, until the conversion is written.
        """
        # TODO: Implement. Raising here, rather than silently returning None,
        # surfaces the gap immediately instead of as a confusing error later
        # in DataLoader batch collation.
        raise NotImplementedError('to_grayscale_bitmask is not implemented yet')

        
        

In [None]:
# Root folder containing test1 ... test7 (see the layout documented in
# VideoDataset); fill in before running.
path_to_data = ''
dataset = VideoDataset(path_to_data)
# num_workers: number of worker subprocesses that call dataset[i] in parallel
# to prepare batches (0 = load in the main process). Values > 0 help here
# because every __getitem__ call does disk I/O.
loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)

