In [1]:
# Setup Warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Global imports
import json
import matplotlib as mlp
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import random
import re
import seaborn as sns

# Typing import
from typing import List, Dict, Union

# Specific imports
from matplotlib.patches import Rectangle, Ellipse
from rich import print
from termcolor import cprint
from time import time

# Local imports
from read_csv import read_meta_info

# Plot styling in a single call: the previous chain (set_theme -> set_style ->
# set_context) kept overriding itself, and a bare sns.color_palette(...) call only
# builds a palette without activating it. Here the 8-colour "hls" palette is applied.
sns.set_theme(context="paper", style="whitegrid", palette=sns.color_palette("hls", 8))

def print_bl():
    """Print a visual separator: a newline string followed by print's own line end."""
    blank_line = "\n"
    print(blank_line)


def print_red(*args):
    """Print every argument in red on one shared line, then terminate the line."""
    colour = "red"
    for message in args:
        # `end=' '` keeps all arguments on the same line
        cprint(message, colour, end=' ')
    print()  # close the line

def print_green(*args):
    """Print every argument in green on one shared line, then terminate the line."""
    colour = "green"
    for message in args:
        # `end=' '` keeps all arguments on the same line
        cprint(message, colour, end=' ')
    print()  # close the line

def print_highlight(*args):
    """Print every argument in magenta on a white background on one line, then end it."""
    foreground, background = "magenta", "on_white"
    for message in args:
        # `end=' '` keeps all arguments on the same line
        cprint(message, foreground, background, end=' ')
    print()  # close the line

def print_blue(*args):
    """Print every argument in light blue on one shared line, then terminate the line."""
    colour = "light_blue"
    for message in args:
        # `end=' '` keeps all arguments on the same line
        cprint(message, colour, end=' ')
    print()  # close the line

In [2]:
class PromptExtractor:
    """
    A class that represents an entire recording of the highD dataset.

    Attributes:
        dataset_location: a path to the directory in which the dataset is stored.
        dataset_index: index of the recording that is to be addressed
        data: the raw data of the recording.
        static_info: the static information of the recording.
        video_info: the video information of the recording.
        frame_length: the length of each frame in milliseconds.
        sampling_period: the time spacing between each frame to consider.
        frame_spacing: the number of frames to skip between each frame.
        window_size: the size of the window to consider.
        windows: a list of dataframes with the windows.
        num_lane_changes: a dictionary with the number of lane changes for each vehicle.
        lane_changers: a list of vehicles that change lanes.
        ego_vehicles: a list of ego vehicles - one for each window.
        groups: a list of vehicle groups.

    Methods:        
        filter_data: Filters the dataset with a certain sampling period.
        get_frame_windows: Splits the dataset into windows of a certain size.
        get_ego_vehicles: Selects the ego vehicle for each window.
        get_groups: Groups vehicles in each window based on their proximity to the ego vehicle.
    """
    def __init__(self, dataset_location: str = None, dataset_index: int = None):
        """
        Initializes the PromptExtractor class.

        Args:
            dataset_location (str): The path to the directory in which the dataset is stored.
            dataset_index (int): The index of the recording that is to be addressed.
        """
        ### Error handling
        if dataset_location is None:
            raise ValueError("Please provide a dataset location.")
        if dataset_index is None or dataset_index < 1 or dataset_index > 60:
            raise ValueError("Please provide a dataset index between 1 and 60")
        
        # retrieve raw data
        self.dataset_index  = dataset_index
        self.dataset_location = dataset_location
        self.df_location = dataset_location + str(dataset_index).zfill(2) + "_tracks.csv"
        self.static_info_location = dataset_location + str(dataset_index).zfill(2) + "_tracksMeta.csv"
        self.video_info_location = dataset_location + str(dataset_index).zfill(2) + "_recordingMeta.csv"

        self.data = pd.read_csv(self.df_location)
        self.static_info = pd.read_csv(self.static_info_location)
        self.video_info = read_meta_info(self.video_info_location)

        #initialize attributes
        self.frame_length = 40 # each frame is 40 ms
        self.sampling_period = 0
        self.frame_spacing = 0.0

        self.window_size = 0
        self.windows = []
        self.new_windows = []

        self.num_lane_changes = {}
        self.lane_changers = []

        self.ego_vehicles = []
        self.groups = []

        
    def filter_data(self, sampling_period: int = 1000) -> pd.DataFrame:
        """
        Filters the dataset with a certain sampling period.

        Args:
            sampling_period (int): The time spacing between each frame to consider.
        Returns:
            The filtered dataset in the form of a pandas dataframe.
        """
        ### Argument validation
        if sampling_period % self.frame_length != 0:
            raise ValueError("Sampling period must be a multiple of 40ms.")
        
        self.sampling_period = sampling_period
        self.frame_spacing = int(self.sampling_period / self.frame_length) #Frames are 40 ms apart

        ### Printing parameters
        print("Filtering data with the following parameters:")
        print_green(f"Sampling period: {self.sampling_period} ms")
        print_green(f"Frame spacing: {self.frame_spacing} frames")

        self.data = self.data[self.data.frame % self.frame_spacing == 0]
        self.data.frame = self.data.frame / self.frame_spacing
        self.data = self.data.astype({'frame': 'int16'})

        return self.data

    def get_frame_windows(self, window_size: int = 5) -> List[pd.DataFrame]:
        """
        Splits the dataset into windows of a certain size.

        Args:
            window_size (int): The size of the window to consider.

        Returns:
            A list of dataframes with the windows.
        """
        ### Runtime error handling
        if self.sampling_period == 0:
            raise RuntimeError("Data has not been filtered - Call filter_data() first with the desired sampling period.")
        
        ### Argument validation            
        if window_size < 1:
            raise ValueError("Window size must be greater than 0.")
        if window_size > len(self.data):
            raise ValueError("Window size must be less than the length of the dataset.")

        self.window_size = window_size

        ### Printing parameters
        print("Creating windows with the following parameters:")
        print_green(f"Window size: {self.window_size} frames")

        self.windows.clear() # clear windows list
        for i in range(1, len(self.data), window_size):
            window = self.data[self.data['frame'].isin(range(i, i+window_size))]
            self.windows.append(window) #no overlap
        return self.windows
    
    def get_ego_vehicles(self) -> List[int]:
        """
        Selects the ego vehicle for each window. The ego vehicle is chosen as a vehicle that is present in all frames of the window and changes lanes at least once.

        Returns:
            A list of ego vehicles - one for each window.
        """
        ### Runtime error handling
        if not self.windows:
            raise RuntimeError("No windows have been created. Please run get_frame_windows() first.")

        lookback = self.window_size
        ego_candidates = []
        all_present = []
        defective_windows = 0
        
        ### Printing parameters
        print("Selecting ego vehicles with the following parameters:")
        print_green(f"Lookback: {lookback} frames")

        #get vehicles that change lanes at least once
        self.num_lane_changes = self.static_info[['id', 'numLaneChanges']].set_index('id').to_dict()['numLaneChanges'] # convert df with id and numLaneChanges to dict [id: numLaneChanges]
        #convert into list of ids that change lanes 
        self.lane_changers = [k for k, v in self.num_lane_changes.items() if v > 0]

        for i, window in enumerate(self.windows):
            #choose ego vehicle 
            ego_candidates.clear()
            all_present.clear()
            id_counts = window.id.value_counts().to_dict()
            all_present = [vehicle_id for vehicle_id in id_counts.keys() if id_counts[vehicle_id] == lookback] #ensure ego vehicle is present in all frames
            if not all_present:
                defective_windows += 1
                continue
            #get ego_candidates that change lane at least once and choose one of them
            ego_candidates = list(set(all_present) & set(self.lane_changers))
            if not ego_candidates: #if there are no lane changers, choose random ego vehicle
                ego_candidate = random.choice(all_present)
                self.ego_vehicles.append(ego_candidate)
                self.new_windows.append(window)
            else: 
                for ego_candidate in ego_candidates:
                    self.new_windows.append(window)
                    self.ego_vehicles.append(ego_candidate) 

        self.windows = self.new_windows

        # Issue warning if there are more than 10% defective windows
        if defective_windows > 0.1 * len(self.windows):
            warnings.warn(f"More than 10% of windows are defective. Defective windows: {defective_windows} - Total windows: {len(self.windows)} - Percentage: {defective_windows/len(self.windows) * 100}%")

        return self.ego_vehicles

    def get_groups(self, bubble_radius: float|int = 50) -> List[pd.DataFrame]:
        """
        Groups vehicles in each window based on their proximity to the ego vehicle. Uses the bubble radius to determine whether a vehicle is in the bubble of the ego vehicle.
        Args:
            bubble_radius (float): The radius of the bubble around the ego vehicle.

        Returns:
            A list of vehicle groups.
        """
        def in_bubble(ego_vehicle, x, radius = bubble_radius):
            '''
            Calculates whether a vehicle is in the bubble of the ego vehicle. Checks distance between ego vehicle and target, as well as driving direction.
            Notation: 
                (x, y) - coordinates of the top left corner of the bounding box of the vehicle.
                (w, h) - width and height of the bounding box.

                Parameters:
                    ego_vehicle: dataframe row with information about the ego vehicle
                    x: dataframe row with information about the vehicle to consider
                Returns:
                    bool: True if the vehicle is in the bubble, False otherwise
                    
            '''
            x1, y1, w1, h1 = ego_vehicle.x, ego_vehicle.y, ego_vehicle.width, ego_vehicle.height
            x2, y2, w2, h2 = x.x, x.y, x.width, x.height
            c1 = np.array([x1 + w1/2, y1 + h1/2])
            c2 = np.array([x2 + w2/2, y2 + h2/2])
            dist = np.linalg.norm(c1-c2)
            dist_check = dist < radius
            sign_check = np.sign(x.xVelocity) == np.sign(ego_vehicle.xVelocity)
            ret = dist_check & sign_check

            return ret
        
        ### Runtime error handling
        if not self.ego_vehicles:
            raise RuntimeError("No ego vehicles have been selected. Please run get_ego_vehicles() first.")
        
        ### Argument validation
        if bubble_radius <= 0:
            raise ValueError("Bubble radius must be greater than 0.")
        
        bubble_radius = float(bubble_radius) # convert to float if int
        
        self.bubble_radius = bubble_radius

        ### Printing parameters
        print("Grouping vehicles with the following parameters:")
        print_green(f"Bubble radius: {bubble_radius} meters")

        current_group = pd.DataFrame(columns=self.data.columns)
        for i, window in enumerate(self.windows):
            current_group = current_group.iloc[0:0] # clear current_group
            #iterate through each frame in the window
            for frame_num, df_group in window.groupby("frame"): #separate the window by frame and iterate through each frame
                ego_vehicle = df_group[df_group.id == self.ego_vehicles[i]]
                #apply mask to window to get vehicles in bubble and concatenate to current_group
                current_group = pd.concat([current_group, df_group[df_group.apply(lambda x: in_bubble(ego_vehicle.iloc[0], x), axis=1)]])
            self.groups.append(current_group) # append current_group to groups

        return self.groups

In [7]:
# Run the full extraction pipeline on one highD recording, timing each step.
# The data directory can be overridden with the HIGHD_DATA_DIR environment variable.
dataset_location = os.environ.get("HIGHD_DATA_DIR", "/home/lmmartinez/Tesis/datasets/highD/data/")
dataset_index = 1
SAMPLING_PERIOD_MS = 1000  # keep one frame per second
WINDOW_SIZE = 4            # frames per window, after sampling
BUBBLE_RADIUS_M = 50       # radius of the bubble around the ego vehicle
random.seed(42)            # ego selection can fall back to random.choice

def timed(step, *args, **kwargs):
    """Call step(*args, **kwargs), print the elapsed wall-clock time, and return its result."""
    start = time()
    result = step(*args, **kwargs)
    print("Time elapsed is:", time() - start)
    return result

scene_data = timed(PromptExtractor, dataset_location=dataset_location, dataset_index=dataset_index)
timed(scene_data.filter_data, sampling_period=SAMPLING_PERIOD_MS)
timed(scene_data.get_frame_windows, window_size=WINDOW_SIZE)
timed(scene_data.get_ego_vehicles)
timed(scene_data.get_groups, bubble_radius=BUBBLE_RADIUS_M);  # ';' hides the list of groups

[32mSampling period: 1000 ms[0m 

[32mFrame spacing: 25 frames[0m 

[32mWindow size: 4 frames[0m 

[32mLookback: 4 frames[0m 



[32mBubble radius: 50.0 meters[0m 

IndexError: list index out of range

In [5]:
first_group = scene_data.groups[0]  # vehicles in the bubble of the first ego vehicle
first_group

Unnamed: 0,frame,id,x,y,width,height,xVelocity,yVelocity,xAcceleration,yAcceleration,...,precedingXVelocity,precedingId,followingId,leftPrecedingId,leftAlongsideId,leftFollowingId,rightPrecedingId,rightAlongsideId,rightFollowingId,laneId
505,1,5,268.74,13.78,4.24,1.82,-42.93,-0.04,-0.43,0.01,...,-42.8,4,0,0,0,0,8,0,9,3
1264,1,9,283.04,9.45,4.85,2.02,-36.16,0.15,-0.18,0.27,...,-31.51,8,0,5,0,0,0,0,0,2
530,2,5,225.57,13.75,4.24,1.82,-43.27,-0.04,-0.23,-0.01,...,-42.76,4,0,0,0,0,8,0,9,3
1289,2,9,246.8,9.78,4.85,2.02,-36.33,0.51,-0.15,0.29,...,-31.5,8,0,5,0,0,0,0,0,2
555,3,5,182.19,13.68,4.24,1.82,-43.4,-0.06,-0.04,0.03,...,-42.77,4,0,0,0,0,8,0,9,3
1314,3,9,210.38,10.44,4.85,2.02,-36.44,0.78,-0.08,0.16,...,-31.4,8,0,5,0,0,0,0,0,2
580,4,5,138.77,13.66,4.24,1.82,-43.39,0.03,0.0,0.05,...,-42.82,4,17,0,0,0,8,0,9,3
1339,4,9,173.91,11.28,4.85,2.02,-36.51,0.84,-0.07,-0.02,...,-31.28,8,16,5,0,17,0,0,0,2
605,5,5,95.43,13.73,4.24,1.82,-43.47,0.04,-0.2,-0.04,...,-42.79,4,9,0,0,0,8,0,16,3
1364,5,9,137.38,12.07,4.85,2.02,-36.59,0.72,-0.07,-0.13,...,-43.47,5,17,0,0,0,8,0,16,3


In [130]:
track_meta = scene_data.static_info  # per-vehicle metadata read from *_tracksMeta.csv
track_meta.iloc[:100]

Unnamed: 0,id,width,height,initialFrame,finalFrame,numFrames,class,drivingDirection,traveledDistance,minXVelocity,maxXVelocity,meanXVelocity,minDHW,minTHW,minTTC,numLaneChanges
0,1,4.85,2.12,1,33,33,Car,2,52.25,40.85,41.30,41.07,-1.00,-1.00,-1.00,0
1,2,4.24,1.92,1,130,130,Car,1,167.44,32.04,32.90,32.48,112.62,3.51,-1.00,0
2,3,3.94,1.92,1,157,157,Car,2,225.23,35.69,36.50,36.13,90.31,2.53,-1.00,0
3,4,5.05,2.22,1,161,161,Car,1,273.49,42.57,42.83,42.76,-1.00,-1.00,-1.00,0
4,5,4.24,1.82,1,182,182,Car,1,313.92,42.50,44.26,43.40,33.64,0.76,24.53,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,96,4.04,1.92,1791,2066,276,Car,1,409.57,36.72,37.65,37.26,39.80,1.07,4.91,1
96,97,5.05,2.12,1828,2102,275,Car,2,402.76,35.61,37.11,36.77,52.36,1.41,15.80,1
97,98,6.67,2.73,1836,2212,377,Truck,2,408.09,26.95,27.39,27.15,26.38,0.97,-1.00,0
98,99,4.45,2.02,1857,2150,294,Car,2,410.35,34.36,35.56,35.03,35.76,1.04,-1.00,1
