### FilterNet
A neural network that filters out Gaussian noise introduced into an image by aggressive accelerations of the vehicle carrying the camera.

In [1]:
import os 
import torch 
import numpy as np 
import pandas as pd 
from scipy.spatial.transform import Rotation as R 
import matplotlib.pyplot as plt
from skimage import io, transform 
from PIL import Image    
from torch.utils.data import Dataset, DataLoader 
from torchvision import transforms, utils 
from typing import List, Dict, Tuple 
from tqdm import tqdm 



In [2]:
# Create dataset class 

class VideoDataset(Dataset):
    """
    Frame-level dataset over a set of recorded videos.

    Each item pairs one input frame (together with the pose and twist
    logged for that frame) with the matching target frame. Images are
    read from disk on demand; only the pose and twist tables are held
    in memory.
    """

    def __init__(self, path_to_data, augmented=False, num_vids=7):
        """
        Go through each video folder and build a map from
        index i in range [0, N), where N is the total number
        of frames in all the videos, to a tuple (j, k), where
        j is the video number and k the frame index in that
        video. For reference, the videos are stored in the
        dataset as:

        data:
        - test1:
          - inputs:
            - bd_poses.csv
            - bd_twists.csv
            - frame00000.png
            - frame00001.png
            - ...
          - targets:
            - frame00000.png
            - frame00001.png
            - ...
        - test2:
            - ...
        - ...

        Frame file names use the five-digit index produced by
        to_zero_pad_idx(). The number of frames of a video is taken
        from the number of rows pandas parses out of its bd_poses.csv.

        Args:
            path_to_data: root directory holding test1 ... test{num_vids}.
            augmented: if True, __getitem__ returns the frame with the pose
                and twist values appended to every pixel as extra channels;
                otherwise it returns them as three separate tensors.
            num_vids: number of video folders to index, starting at test1.

        Raises:
            FileNotFoundError: if a pose or twist CSV is missing.
        """
        self.augmented = augmented
        self.num_vids = num_vids
        self.path_to_data = path_to_data
        self.idx_map: List[Tuple[int, int]] = []
        self.poses: Dict[int, np.ndarray] = {}
        self.twists: Dict[int, np.ndarray] = {}
        for i in tqdm(range(1, self.num_vids + 1)):
            # Define path to pose
            pose_path = os.path.join(path_to_data, f'test{i}/inputs/bd_poses.csv')
            twist_path = os.path.join(path_to_data, f'test{i}/inputs/bd_twists.csv')

            # Check that the files exist before trying to parse them
            if not os.path.isfile(pose_path):
                raise FileNotFoundError(f"Missing pose file: {pose_path}")
            if not os.path.isfile(twist_path):
                raise FileNotFoundError(f"Missing twist file: {twist_path}")

            # Load the per-frame pose and twist tables (one row per frame)
            self.poses[i] = pd.read_csv(pose_path).to_numpy()
            self.twists[i] = pd.read_csv(twist_path).to_numpy()

            # One dataset entry per pose row
            num_frames = self.poses[i].shape[0]

            # Update index map with (video number, frame index) pairs
            self.idx_map.extend((i, frame_idx) for frame_idx in range(num_frames))

        self.total_num_frames = len(self.idx_map)

    def __len__(self):
        """
        Return length of dataset as computed in __init__() function.
        """
        return self.total_num_frames

    def __getitem__(self, idx):
        """
        Use map built in __init__() to retrieve the image,
        pose, and twist directly from the dataset.
        This avoids loading the entire dataset which
        overwhelms RAM.

        Args:
            idx: global frame index (int or a scalar tensor).

        Returns:
            A dict with keys 'input' and 'target'. 'target' is the
            binarized target frame as a float tensor of shape (H, W).
            'input' is, when self.augmented is True, a single float tensor
            of shape (H, W, 1 + 7 + T) holding the binarized frame, the
            7-element pose vector and the T-element twist row for every
            pixel; otherwise it is the tuple (frame, pose, twist).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Define path to data
        test_idx, frame_idx = self.idx_map[idx]
        padded_frame_idx = self.to_zero_pad_idx(frame_idx)
        input_img_path = os.path.join(self.path_to_data,
                                f'test{test_idx}/inputs/frame{padded_frame_idx}.png')
        output_img_path = os.path.join(self.path_to_data,
                                f'test{test_idx}/targets/frame{padded_frame_idx}.png')

        # Load and process data
        input_frame = self.to_grayscale(io.imread(input_img_path))
        output_frame = self.to_grayscale(io.imread(output_img_path))
        pose = self.pose_vector_from_matrix(self.poses[test_idx][frame_idx])
        twist = self.twists[test_idx][frame_idx]

        frame_t = torch.from_numpy(input_frame).float()
        pose_t = torch.from_numpy(pose).float()
        twist_t = torch.from_numpy(twist).float()
        target_t = torch.from_numpy(output_frame).float()

        if self.augmented:
            h, w = input_frame.shape
            # expand() broadcasts the per-frame vectors over the image grid
            # without materialising full-size copies; cat() then writes the
            # single (H, W, C) result.
            augmented_frame = torch.cat(
                (
                    frame_t.view(h, w, 1),
                    pose_t.view(1, 1, -1).expand(h, w, -1),
                    twist_t.view(1, 1, -1).expand(h, w, -1),
                ),
                dim=-1,
            )

            return {'input': augmented_frame, 'target': target_t}

        return {'input': (frame_t, pose_t, twist_t), 'target': target_t}

    def to_zero_pad_idx(self, idx):
        """
        Convert frame index from regular index to lexicographical index.
        e.g. 1 -> 00001, 12 -> 00012
        """
        return f'{idx:05d}'

    def pose_vector_from_matrix(self, pose):
        """
        Convert 4x4 pose matrix (as a flattened length 16 vector) into a position and quaternion length 7 vector.

        The result is [x, y, z, qx, qy, qz, qw]: the translation column of
        the matrix followed by the rotation as a quaternion in SciPy's
        default scalar-last order.
        """
        pose = pose.reshape(4, 4)
        position = pose[:3, 3].reshape(3, 1)
        orientation = pose[:3, :3]

        quat = R.from_matrix(orientation).as_quat().reshape(-1, 1)
        # Re-normalise defensively so the quaternion is unit length
        norm_quat = quat / np.linalg.norm(quat)

        return np.vstack((position, norm_quat)).reshape(-1)

    def to_grayscale(self, image):
        """
        Binarize an image into a single-channel float32 mask.

        For a multi-channel image (H, W, C) only channel 0 is used; an
        image that is already single-channel (H, W) is used as is.
        Pixels strictly greater than 127 become 1.0, all others 0.0.
        """
        #TODO: Try float16 type
        image = np.asarray(image)
        # A 2-D image has no channel axis; indexing [..., 0] would
        # take its first column instead of a channel.
        channel = image if image.ndim == 2 else image[..., 0]
        # Threshold assumes 8-bit pixel values (0-255) — TODO confirm
        return (channel > 127).astype(np.float32)
        

        

In [36]:
# Make dataloader and load data
# The dataset root can be overridden with the FILTERNET_DATA_DIR environment
# variable; the original development path is kept as the default.
path_to_data = os.environ.get(
    'FILTERNET_DATA_DIR',
    '/home/jrached/cv_project_code/project/data/filter_net/processed_flow',
)
dataset = VideoDataset(path_to_data, augmented=True)
# num_workers is the number of worker subprocesses that load samples in
# parallel with the main process (0 would load everything in the main process).
loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=2)



100%|██████████| 7/7 [00:00<00:00, 165.48it/s]


In [37]:
# Pull the first batch from the loader and split it into inputs and targets
datapoint = next(iter(loader))
features, labels = datapoint['input'], datapoint['target']


In [31]:
# Shape of the augmented input batch. Each sample is built as
# (height, width, channels), with the image, pose and twist values
# concatenated along the last axis; batching adds the leading axis.
# NOTE(review): the saved output reports a batch of 64 while the loader cell
# sets batch_size=32 — looks stale; re-run from a fresh kernel to confirm.
features.shape

torch.Size([64, 480, 848, 14])