# Loading Data and Visualization

Make sure you have set up the `ritw` environment by following the instructions below:
```
   git clone git@github.com:facebookresearch/reading_in_the_wild.git
   cd reading_in_the_wild
   conda create -n ritw python=3.10
   conda activate ritw
   pip install -r requirements.txt
```

In [None]:
import os
import argparse
import torch
import cv2
import pandas as pd
import numpy as np
import json
from projectaria_tools.core import data_provider
from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions
from projectaria_tools.core.stream_id import StreamId
from torch.utils.data import DataLoader
import glob
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image

# We will also use utility functions defined in the models directory
import sys
sys.path.append("../models")
from model import MultimodalTransformer
from projection_utils import project_gaze
from utils import create_sampled_array, draw_projections_on_image

## Loading a Sequence
Make sure you have downloaded a sequence together with its VRS recording, MPS outputs, and metadata.

In [None]:
root_dir = "your root directory"  # directory that contains the downloaded recordings
vid_uid = "recording_925696276074411"  # replace with any sequence that you have downloaded
recordings_dir = os.path.join(root_dir, vid_uid)

# Load Metadata.
# Decode explicitly as UTF-8 (the JSON interchange encoding) instead of relying on
# the platform's locale default, which differs e.g. on Windows.
metadata_path = os.path.join(recordings_dir, "metadata.json")
with open(metadata_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)

In [None]:
# Load Eye Gaze.
# Use the personalized eye-gaze CSV when it exists in the MPS output folder,
# otherwise fall back to the general eye-gaze CSV.
vrs_path = os.path.join(recordings_dir, "recording.vrs")
eye_gaze_dir = os.path.join(recordings_dir, "mps", "eye_gaze")
personalized_gaze_path = os.path.join(eye_gaze_dir, "personalized_eye_gaze.csv")
gaze_path = (
    personalized_gaze_path
    if os.path.exists(personalized_gaze_path)
    else os.path.join(eye_gaze_dir, "general_eye_gaze.csv")
)
gaze = project_gaze(gaze_path, vrs_path=vrs_path)

# Load RGB: open the VRS recording and restrict queued delivery to the RGB stream.
# (deliver_option is not consumed by the cells in this notebook.)
rgb_stream_id = StreamId("214-1")
provider = data_provider.create_vrs_data_provider(vrs_path)
deliver_option = provider.get_default_deliver_queued_options()
deliver_option.deactivate_stream_all()
deliver_option.activate_stream(rgb_stream_id)

### Let's preview RGB thumbnails

In [None]:
sample_count = 10  # how many samples to visualize
resize_ratio = 10  # reduce the image size by 10 to generate thumbnails

rgb_stream_id = StreamId("214-1")
time_domain = TimeDomain.DEVICE_TIME  # query data based on device time
option = TimeQueryOptions.CLOSEST  # take the frame nearest to each query time
start_time = provider.get_first_time_ns(rgb_stream_id, time_domain)
end_time = provider.get_last_time_ns(rgb_stream_id, time_domain)
# Evenly spaced query times spanning the whole RGB stream.
sample_timestamps = np.linspace(start_time, end_time, sample_count)

image_config = provider.get_image_configuration(rgb_stream_id)
width = image_config.image_width
height = image_config.image_height

# One horizontal strip holding all downscaled frames side by side.
thumbnail = Image.new(
    "RGB", (int(width * sample_count / resize_ratio), int(height / resize_ratio))
)
current_width = 0

for sample in sample_timestamps:
    image_tuple = provider.get_image_data_by_time_ns(rgb_stream_id, int(sample), time_domain, option)
    image_array = image_tuple[0].to_numpy_array()
    image = Image.fromarray(image_array)
    new_size = (
        int(image.size[0] / resize_ratio),
        int(image.size[1] / resize_ratio),
    )
    # PIL angles are counter-clockwise, so -90 is a clockwise quarter turn.
    # Without expand=True PIL keeps the original canvas size, which is only
    # lossless for square frames.
    image = image.resize(new_size).rotate(-90)
    thumbnail.paste(image, (current_width, 0))
    current_width = int(current_width + width / resize_ratio)

# Import only `display`: importing IPython's `Image` here would shadow PIL's
# `Image`, breaking `Image.new` / `Image.fromarray` when this cell is re-run.
from IPython.display import display
display(thumbnail)

### Visualizing the gaze trajectories and the foveated patches over the RGB frames

In [None]:
input_hz = 60  # input gaze frequency
input_sec = 2  # snippet (sample) duration that is fed to the model
crop_size = 64  # resolution of foveated rgb patch (5 degree FoV)
input_length = input_sec * input_hz

xyz_columns = ['transformed_gaze_x', 'transformed_gaze_y', 'transformed_gaze_z']
xy_columns = ['projected_point_2d_x', 'projected_point_2d_y']
# Subsampling step handed to create_sampled_array; presumably the raw gaze
# stream is 60 Hz, so this yields input_hz samples per second — confirm.
source_stride = 60 // input_hz

# 3D gaze directions with gaps (NaNs) forward-filled from the last valid row.
gaze_dirs = gaze[xyz_columns].ffill()
# input_length + 1 samples per snippet, so the difference below yields input_length steps.
gaze_windows = create_sampled_array(gaze_dirs, num_samples=input_length + 1, stride=source_stride)
# Consecutive differences along axis 1 (assumed to be time within each snippet),
# scaled by input_hz to a per-second rate.
gaze_sequence = torch.Tensor(np.diff(gaze_windows, axis=1) * input_hz)
num_gaze = gaze_sequence.size(0)

# 2D projected gaze points (used as pixel coordinates below) and their timestamps (us).
gaze_xy = np.array(gaze[xy_columns].ffill())
gaze_timestamps = gaze['tracking_timestamp_us'].tolist()

In [None]:
# Get a short snippet, and plot projected gazes (colored dots) and the rgb crop (red)

import tempfile

import mediapy as media

rgb_stream_id = StreamId("214-1")
half_crop = crop_size // 2
num_context = 20  # visualize +/- 20 samples around the chosen one
trail_len, trail_step = 60, 6  # draw every 6th of the 60 gaze samples preceding the current one

i = len(gaze_timestamps) // 2  # i-th gaze sample. Here we chose the center of the sequence.
frames = []
for j in range(-num_context, num_context):
    gaze_idx = i + j + input_length
    # Gaze timestamps are in microseconds; the VRS provider is queried in nanoseconds.
    query_ns = int(gaze_timestamps[gaze_idx] * 1000)
    im = provider.get_image_data_by_time_ns(
        rgb_stream_id, query_ns, TimeDomain.DEVICE_TIME, TimeQueryOptions.CLOSEST
    )[0].to_numpy_array()
    im = cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
    # Take the bounds from the actual (rotated) frame instead of a hard-coded 1408,
    # so recordings at other RGB resolutions are clamped correctly.
    im_h, im_w = im.shape[:2]
    # Clamp the crop centre so the patch stays inside the frame. x is mirrored
    # (im_w - x) as in the original; presumably this accounts for the rotation above.
    x_ = im_w - np.clip(int(gaze_xy[gaze_idx, 0]), half_crop, im_w - half_crop)
    y_ = np.clip(int(gaze_xy[gaze_idx, 1]), half_crop, im_h - half_crop)
    # Foveated patch around the gaze point (kept for inspection; not drawn).
    gaze_crop = im[y_ - half_crop:y_ + half_crop, x_ - half_crop:x_ + half_crop]

    trail = gaze_xy[gaze_idx - trail_len:gaze_idx:trail_step]
    im_draw = draw_projections_on_image(im, trail[:, 0], trail[:, 1])
    cv2.rectangle(
        im_draw,
        (x_ - half_crop, y_ - half_crop),
        (x_ + half_crop, y_ + half_crop),
        color=(255, 0, 0),
        thickness=5,
    )
    im_draw = cv2.resize(im_draw, (704, 704))

    frames.append(im_draw)

# Display the video (written to the OS temp directory so this also works off Linux).
output_vid_path = os.path.join(tempfile.gettempdir(), 'viz.mp4')
media.write_video(output_vid_path, frames, fps=30)
media.show_video(media.read_video(output_vid_path), fps=30)
