In [49]:
# Setup Warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Global imports
import io
import json
import os
import random
import re

import matplotlib as mlp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Typing import
from typing import List, Dict, Union, Tuple

# Specific imports
from copy import deepcopy
from matplotlib.patches import Rectangle, Ellipse
from matplotlib.animation import FuncAnimation, PillowWriter
from rich import print
from termcolor import cprint
from time import time

# Plot styling: "paper" context, white grid and an 8-colour HLS palette, applied in a
# single call. sns.color_palette() on its own only *returns* a palette; it has to be
# handed to set_theme/set_palette to become the active colour cycle.
sns.set_theme(context="paper", style="whitegrid", palette=sns.color_palette("hls", 8))

def _cprint_each(args, *color_args):
    """Print each item of ``args`` with termcolor's cprint using ``color_args``.

    Every item is printed with ``end=' '`` so all arguments (and consecutive calls)
    stay on the same line; no trailing newline is emitted.
    """
    for arg in args:
        cprint(arg, *color_args, end=' ')


def print_bl():
    """Print a blank separator: the string "\\n" followed by print's own newline."""
    print("\n")


def print_red(*args):
    """Print all arguments in red on the current line."""
    _cprint_each(args, "red")


def print_green(*args):
    """Print all arguments in green on the current line."""
    _cprint_each(args, "green")


def print_highlight(*args):
    """Print all arguments in magenta on a white background on the current line."""
    _cprint_each(args, "magenta", "on_white")


def print_blue(*args):
    """Print all arguments in blue on the current line."""
    _cprint_each(args, "blue")

In [50]:
class GroupsLoader:
    """
    Class to load the groups from the json files.

    Reads two files from ``groups_path``:
        * ``groups_{dataset_index}.json``: a JSON object; each value is parsed with
          ``pd.read_json`` into one group dataframe (values are presumably strings
          produced by ``DataFrame.to_json`` - confirm against the export script).
        * ``ego_vehicles_{dataset_index}.json``: the ego vehicle id of each group.

    Attributes:
        groups_path (str): The path to the directory holding the groups json files.
        dataset_index (int): The dataset index used in the file names.
        groups (List[pd.DataFrame]): The list of groups dataframes (filled by load_groups).
        ego_vehicles (List[int]): The ego vehicle id of each group (filled by load_groups).
    Methods:
        load_groups: Load the groups from the json files.
    """

    def __init__(self, groups_path: str, dataset_index: int = 1):
        self.groups_path = groups_path
        self.dataset_index = dataset_index
        self.groups = []
        self.ego_vehicles = []

    @staticmethod
    def _read_group(entry) -> pd.DataFrame:
        """
        Parse one value of the groups file into a dataframe.

        Inline JSON text is wrapped in StringIO because passing a literal JSON string
        to ``pd.read_json`` is deprecated since pandas 2.1. Any other value (e.g. a
        path to a JSON file) is passed through unchanged, as before.
        """
        if isinstance(entry, str) and entry.lstrip().startswith(('{', '[')):
            return pd.read_json(io.StringIO(entry))
        return pd.read_json(entry)

    def load_groups(self) -> Tuple[List[pd.DataFrame], List[int]]:
        """
        Load the groups from the json files.

        Args:
            None
        Returns:
            A tuple containing the groups and the ego vehicles. Both are also stored
            on the instance as ``self.groups`` and ``self.ego_vehicles``.
        Raises:
            FileNotFoundError: If either json file is missing.
        """
        groups_file = os.path.join(self.groups_path, f"groups_{self.dataset_index}.json")
        ego_file = os.path.join(self.groups_path, f"ego_vehicles_{self.dataset_index}.json")

        with open(groups_file, 'r', encoding='utf-8') as fh:
            data = json.load(fh)
        # JSON objects keep their key order, so group i matches ego_vehicles[i].
        self.groups = [self._read_group(entry) for entry in data.values()]

        with open(ego_file, 'r', encoding='utf-8') as fh:
            self.ego_vehicles = json.load(fh)

        return self.groups, self.ego_vehicles

In [61]:
class PromptPopulator:
    """
    Class to populate the prompts.

    Loads the vehicle groups and their ego vehicle ids (via GroupsLoader) and three
    text templates, and builds one prompt string per group: role, task and
    instruction templates, a line listing the vehicles in the group, and one
    sentence per vehicle per frame expressed in the ego vehicle's reference frame.

    Attributes:
        groups_location (str): The path to the groups json files.
        template_path (str): Directory containing the role/task/instructions templates.
        dataset_index (int): The dataset index.
        prompts (List[str]): Initialised empty; no method of this class fills it.
        groups (List[pd.DataFrame]): The raw groups loaded by GroupsLoader.
        ego_vehicles (List[int]): The ego vehicle id of each group.
    """

    def __init__(self, groups_location: str, template_path: str, dataset_index: int = 1):
        self.groups_location = groups_location
        self.template_path = template_path
        self.dataset_index = dataset_index
        self.prompts = []

        self.groups_loader = GroupsLoader(self.groups_location, self.dataset_index)
        self.groups, self.ego_vehicles = self.groups_loader.load_groups()

        self.instructions_template, self.task_template, self.role_template = self.load_templates()

    def load_templates(self) -> List[str]:
        """
        Load the prompt templates.

        Args:
            None
        Returns:
            A list containing the instructions, task and role templates (in that order).
        """
        templates = []
        for name in ("instructions_template.txt", "task_template.txt", "role_template.txt"):
            with open(os.path.join(self.template_path, name), 'r', encoding='utf-8') as fh:
                templates.append(fh.read())
        return templates

    def relative_group(self, group_index: int) -> pd.DataFrame:
        """
        Transforms all positions to the ego vehicle reference frame for each frame.

        Args:
            group_index (int): The index of the group to transform.
        Returns:
            A new dataframe containing the transformed group, rows ordered by frame.
            ``self.groups`` is not modified.
        Raises:
            ValueError: If the group is empty or the ego vehicle is missing from a frame.
        """
        ego_vehicle_id = self.ego_vehicles[group_index]
        group = self.groups[group_index]

        transformed_groups = []

        for frame, frame_group in group.groupby('frame'):
            # Explicit copy: columns are overwritten below, and assigning on a groupby
            # chunk can trigger SettingWithCopyWarning.
            frame_group = frame_group.copy()

            # Identify the ego vehicle for the current frame
            ego_rows = frame_group[frame_group['id'] == ego_vehicle_id]
            if ego_rows.empty:
                raise ValueError(
                    f"Ego vehicle {ego_vehicle_id} is missing from frame {frame} of group {group_index}."
                )
            ego_vehicle = ego_rows.iloc[0]

            # These signs might not make sense at first - Draw a diagram with two two-lane roads.
            # Upper two lanes are moving right to left, lower two lanes are moving left to right.
            if frame_group['xVelocity'].mean() < 0:  # vehicles are moving right to left
                frame_group['x'] = ego_vehicle['x'] - frame_group['x']
                frame_group['y'] = - ego_vehicle['y'] + frame_group['y']
                frame_group['xVelocity'] = - frame_group['xVelocity']
                frame_group['xAcceleration'] = - frame_group['xAcceleration']
            else:  # vehicles are moving left to right
                frame_group['x'] = frame_group['x'] - ego_vehicle['x']
                frame_group['y'] = - frame_group['y'] + ego_vehicle['y']
                frame_group['yVelocity'] = - frame_group['yVelocity']
                frame_group['yAcceleration'] = - frame_group['yAcceleration']

            transformed_groups.append(frame_group)

        if not transformed_groups:
            raise ValueError(f"Group {group_index} has no frames to transform.")

        transformed_group = pd.concat(transformed_groups)

        # Shift frames so the latest frame is t = 0 and earlier frames are negative
        # (the earliest becomes min - max).
        transformed_group['frame'] = transformed_group['frame'] - transformed_group['frame'].max()

        return transformed_group

    def row2str(self, row: pd.Series, decimals: int = 2) -> str:
        """
        Convert a row to a string.

        Frame and id are printed as integers and every physical quantity is rounded
        to ``decimals`` places: rows from ``DataFrame.iterrows`` on an all-numeric
        frame are upcast to float, which would otherwise print ids as ``3.0`` and leak
        float-subtraction noise such as ``-33.51999999999998`` into the prompt.
        The ``width`` column is reported as the vehicle length and ``height`` as its width.

        Args:
            row (pd.Series): The row to convert (needs frame, id, x, y, xVelocity,
                yVelocity, xAcceleration, yAcceleration, width and height).
            decimals (int): Number of decimals kept for the physical quantities.
        Returns:
            A string containing the row information.
        """
        def num(value) -> float:
            # "+ 0.0" turns the -0.0 produced by the sign flips into a plain 0.0.
            return round(float(value), decimals) + 0.0

        info = (
            f"At t={int(row['frame'])} s, vehicle with id {int(row['id'])} is at position "
            f"({num(row['x'])}, {num(row['y'])}) with longitudinal speed "
            f"{num(row['xVelocity'])} m/s and lateral speed {num(row['yVelocity'])} m/s. "
            f"The longitudinal acceleration is {num(row['xAcceleration'])} m/s^2 "
            f"and the lateral acceleration is {num(row['yAcceleration'])} m/s^2. "
            f"The length of the vehicle is {num(row['width'])} m and its width is "
            f"{num(row['height'])} m."
        )

        return info

    def get_prompt_static_info(self, group_index: int = 0) -> str:
        """
        Get the static information for the prompt.

        Args:
            group_index (int): The index of the group to get the static information.
        Returns:
            A string listing the vehicle ids in the group and the ego vehicle id.
        """
        group = self.groups[group_index]
        ego_vehicle_id = self.ego_vehicles[group_index]

        # A stable sort by frame yields ids in the same order as the frame-ordered
        # output of relative_group(), without re-running the whole transform.
        vehicle_ids = group.sort_values('frame', kind='stable')['id'].unique()

        info = "Vehicles present in the group have ids: " + ", ".join(str(vid) for vid in vehicle_ids) + ". "
        info += f"The ego vehicle is vehicle with id {ego_vehicle_id}."

        return info

    def populate_prompt(self, group_index: int = 0) -> str:
        """
        Populate the prompts. Takes a template and a vehicle group and the prompt template, and returns a string with the populated prompt.

        Args:
            group_index (int): The index of the group to populate.
        Returns:
            A string with the populated prompt.
        """
        group = self.relative_group(group_index)
        header = (
            self.role_template + "\n\n" +
            self.task_template + "\n\n" +
            self.instructions_template + "\n\n" +
            self.get_prompt_static_info(group_index) + "\n\n" +
            "The information for each vehicle is as follows:\n"
        )

        # Collect the per-row sentences and join once instead of growing a string.
        lines = [
            self.row2str(row) + "\n"
            for _, frame_group in group.groupby('frame')
            for _, row in frame_group.iterrows()
        ]

        return header + "".join(lines)

    def save_prompt_to_file(self, group_index: int = 0, filename: str = 'prompt.txt'):
        """
        Save the populated prompt to a file.

        The parent directory of ``filename`` is created if it does not exist, and the
        file is written as UTF-8.

        Args:
            group_index (int): The index of the group to populate.
            filename (str): The name of the file to save the prompt.
        """
        prompt = self.populate_prompt(group_index)
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(filename, 'w', encoding='utf-8') as file:
            file.write(prompt)


In [62]:
# Location of the pre-computed highD groups. Set the HIGHD_GROUPS_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
groups_location = os.environ.get(
    "HIGHD_GROUPS_DIR",
    "/Users/lmiguelmartinez/Tesis/datasets/highD/groups_1000_lookback5",
)
template_path = "./prompts"

pp = PromptPopulator(groups_location=groups_location, template_path=template_path, dataset_index=1)
prompt = pp.populate_prompt(1)

In [63]:
# Create the output directory first so this cell also works on a fresh checkout.
os.makedirs('./test_save', exist_ok=True)
pp.save_prompt_to_file(group_index=1, filename='./test_save/prompt.txt')

In [60]:
# Preview only the first rows of the raw group instead of dumping the whole frame.
pp.groups[1].head(10)

Unnamed: 0,frame,id,x,y,width,height,xVelocity,yVelocity,xAcceleration,yAcceleration,...,precedingXVelocity,precedingId,followingId,leftPrecedingId,leftAlongsideId,leftFollowingId,rightPrecedingId,rightAlongsideId,rightFollowingId,laneId
0,1,3,216.16,21.63,3.94,1.92,35.88,-0.03,0.2,-0.02,...,41.81,14,7,0,0,0,10,0,11,5
1,1,7,182.64,22.02,4.75,2.02,32.91,-0.16,0.37,0.05,...,35.88,3,12,0,0,0,11,0,0,5
2,1,10,223.57,25.63,9.2,2.5,23.36,-0.09,0.1,-0.09,...,23.26,6,11,14,0,3,0,0,0,6
3,1,11,205.24,25.52,4.14,1.92,24.8,-0.05,-0.29,-0.09,...,23.36,10,0,3,0,7,0,0,0,6
4,2,3,252.16,21.58,3.94,1.92,36.05,-0.04,0.13,0.0,...,42.24,14,7,0,0,0,6,10,11,5
5,2,7,215.76,21.89,4.75,2.02,33.28,-0.15,0.36,0.01,...,36.05,3,12,0,0,0,11,0,0,5
6,2,10,247.03,25.5,9.2,2.5,23.35,-0.07,-0.12,0.02,...,23.31,6,11,14,3,7,0,0,0,6
7,2,11,229.9,25.4,4.14,1.92,24.49,-0.11,-0.33,0.02,...,23.35,10,0,3,0,7,0,0,0,6
8,3,3,288.28,21.55,3.94,1.92,36.14,-0.04,0.08,0.01,...,42.74,14,7,0,0,0,6,0,10,5
9,3,7,249.21,21.72,4.75,2.02,33.63,-0.14,0.35,0.07,...,36.14,3,12,0,0,0,11,0,0,5
