In [78]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import json
import matplotlib as mlp
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns

from matplotlib.patches import Rectangle, Ellipse
from rich import print
from time import time

from typing import List, Dict, Union

from read_csv import read_meta_info

sns.set_theme('notebook')
sns.set_style("whitegrid")
sns.set_context("paper")
sns.color_palette("hls", 8)

def print_bl():
    print("\n")

In [140]:
class PromptExtractor:
    """
    A class that represents an entire recording of the highD dataset.

    Attributes:
        dataset_location: a path to the directory in which the dataset is stored.
        dataset_index: index of the recording that is to be addressed
        sampling_period: time period between sampled frames. Needs to be a multiple of 40 ms    
    
    """
    def __init__(self, dataset_location: str = None, dataset_index: int = None):
        self.dataset_index  = dataset_index
        self.dataset_location = dataset_location
        self.frame_length = 40 # each frame is 40 ms
        self.df_location = dataset_location + str(dataset_index).zfill(2) + "_tracks.csv"
        self.static_info_location = dataset_location + str(dataset_index).zfill(2) + "_tracksMeta.csv"
        self.video_info_location = dataset_location + str(dataset_index).zfill(2) + "_recordingMeta.csv"

        self.data = pd.read_csv(self.df_location)
        self.static_info = pd.read_csv(self.static_info_location)
        self.video_info = read_meta_info(self.video_info_location)

        #get dict of numLaneChanges
        print(self.static_info.columns)
        self.num_lane_changes = self.static_info[['id', 'numLaneChanges']].set_index('id').to_dict()['numLaneChanges']
        display(self.num_lane_changes)

        
    def filter_data(self, sampling_period: int = 1000) -> pd.DataFrame:
        """
        Filters the dataset with a certain sampling period.

        Args:
            sampling_period (int): The time spacing between each frame to consider.

        Returns:
            A list of dictionaries with one frame per dictionary.
        """
        
        self.sampling_period = sampling_period
        self.frame_spacing = int(self.sampling_period / self.frame_length) #Frames are 40 ms apart

        # Filter data
        self.data = self.data[self.data.frame % self.frame_spacing == 0]
        self.data.frame = self.data.frame / self.frame_spacing
        self.data = self.data.astype({'frame': 'int16'})

    def get_frame_windows(self, window_size: int = 5) -> List[pd.DataFrame]:
        """
        Splits the dataset into windows of a certain size.

        Args:
            window_size (int): The size of the window to consider.

        Returns:
            A list of dataframes with the windows.
        """

        self.window_size = window_size
        self.windows = []

        for i in range(1, len(self.data), window_size):
            window = self.data[self.data['frame'].isin(range(i, i+window_size))]
            self.windows.append(window) #no overlap
        self.windows[0].head()

    def get_groups(self, bubble_radius: float = 50.0, lookback: int = 5):
        def in_bubble(ego_vehicle, x, radius = bubble_radius):
            '''
            Calculates whether a vehicle is in the bubble of the ego vehicle. Checks distance between ego vehicle and target, as well as driving direction.
            Notation: 
                (x, y) - coordinates of the top left corner of the bounding box of the vehicle.
                (w, h) - width and height of the bounding box.

                Parameters:
                    ego_vehicle: dataframe row with information about the ego vehicle
                    x: dataframe row with information about the vehicle to consider
                Returns:
                    bool: True if the vehicle is in the bubble, False otherwise
                    
            '''
            x1, y1, w1, h1 = ego_vehicle.x, ego_vehicle.y, ego_vehicle.width, ego_vehicle.height
            x2, y2, w2, h2 = x.x, x.y, x.width, x.height
            c1 = np.array([x1 + w1/2, y1 + h1/2])
            c2 = np.array([x2 + w2/2, y2 + h2/2])
            dist = np.linalg.norm(c1-c2)

            return dist < radius & np.sign(x.xVelocity) == np.sign(ego_vehicle.xVelocity)
        
        self.bubble_radius = bubble_radius
        self.lookback = lookback

        self.groups = []
        ego_candidates = []

        for window in self.windows:
            #choose ego vehicle - one present in every frame
            ego_candidates.clear()
            id_counts = window.id.value_counts().to_dict()
            [ego_candidates.append(vehicle_id) for vehicle_id in id_counts.keys() if id_counts[vehicle_id] == lookback]
            if not ego_candidates:
                continue

            break
            

In [141]:
dataset_location = "/home/lmmartinez/Tesis/datasets/highD/data/"
dataset_index = 1
start = time()
scene_data = PromptExtractor(dataset_location=dataset_location, dataset_index=dataset_index)
end = time()
print("Time elapsed is:", end - start)
start = time()
scene_data.filter_data(sampling_period=1000)
scene_data.get_frame_windows(window_size=5)
scene_data.get_groups()
end = time()
print("Time elapsed is:", end - start)

{1: 0,
 2: 0,
 3: 0,
 4: 0,
 5: 0,
 6: 0,
 7: 0,
 8: 0,
 9: 1,
 10: 0,
 11: 1,
 12: 0,
 13: 0,
 14: 0,
 15: 0,
 16: 0,
 17: 0,
 18: 0,
 19: 0,
 20: 0,
 21: 0,
 22: 0,
 23: 1,
 24: 0,
 25: 0,
 26: 0,
 27: 0,
 28: 0,
 29: 0,
 30: 1,
 31: 1,
 32: 0,
 33: 0,
 34: 0,
 35: 0,
 36: 0,
 37: 0,
 38: 0,
 39: 0,
 40: 0,
 41: 0,
 42: 0,
 43: 0,
 44: 0,
 45: 0,
 46: 0,
 47: 0,
 48: 1,
 49: 0,
 50: 0,
 51: 0,
 52: 0,
 53: 0,
 54: 0,
 55: 0,
 56: 0,
 57: 0,
 58: 1,
 59: 0,
 60: 1,
 61: 0,
 62: 0,
 63: 0,
 64: 0,
 65: 0,
 66: 0,
 67: 0,
 68: 0,
 69: 0,
 70: 0,
 71: 0,
 72: 0,
 73: 0,
 74: 0,
 75: 0,
 76: 0,
 77: 0,
 78: 0,
 79: 1,
 80: 0,
 81: 0,
 82: 0,
 83: 0,
 84: 0,
 85: 0,
 86: 0,
 87: 0,
 88: 0,
 89: 0,
 90: 0,
 91: 0,
 92: 0,
 93: 1,
 94: 0,
 95: 0,
 96: 1,
 97: 1,
 98: 0,
 99: 1,
 100: 0,
 101: 0,
 102: 0,
 103: 0,
 104: 0,
 105: 0,
 106: 0,
 107: 0,
 108: 0,
 109: 0,
 110: 0,
 111: 1,
 112: 0,
 113: 0,
 114: 0,
 115: 1,
 116: 1,
 117: 0,
 118: 1,
 119: 1,
 120: 0,
 121: 0,
 122: 0,
 123: 2,
 

In [117]:
scene_data.windows[0].to_csv("window.csv")

In [130]:
scene_data.static_info.head(100)

Unnamed: 0,id,width,height,initialFrame,finalFrame,numFrames,class,drivingDirection,traveledDistance,minXVelocity,maxXVelocity,meanXVelocity,minDHW,minTHW,minTTC,numLaneChanges
0,1,4.85,2.12,1,33,33,Car,2,52.25,40.85,41.30,41.07,-1.00,-1.00,-1.00,0
1,2,4.24,1.92,1,130,130,Car,1,167.44,32.04,32.90,32.48,112.62,3.51,-1.00,0
2,3,3.94,1.92,1,157,157,Car,2,225.23,35.69,36.50,36.13,90.31,2.53,-1.00,0
3,4,5.05,2.22,1,161,161,Car,1,273.49,42.57,42.83,42.76,-1.00,-1.00,-1.00,0
4,5,4.24,1.82,1,182,182,Car,1,313.92,42.50,44.26,43.40,33.64,0.76,24.53,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,96,4.04,1.92,1791,2066,276,Car,1,409.57,36.72,37.65,37.26,39.80,1.07,4.91,1
96,97,5.05,2.12,1828,2102,275,Car,2,402.76,35.61,37.11,36.77,52.36,1.41,15.80,1
97,98,6.67,2.73,1836,2212,377,Truck,2,408.09,26.95,27.39,27.15,26.38,0.97,-1.00,0
98,99,4.45,2.02,1857,2150,294,Car,2,410.35,34.36,35.56,35.03,35.76,1.04,-1.00,1
