In [None]:
!pip install -q transformers huggingface_hub[cli] bitsandbytes accelerate

In [None]:
# Install environment and agent
!pip install highway-env
# TODO: we use the bleeding edge version because the current stable version does not support the latest gym>=0.21 versions. Revert back to stable at the next SB3 release.
!pip install git+https://github.com/DLR-RM/stable-baselines3

# Environment
import gymnasium as gym
import highway_env

# Agent
from stable_baselines3 import DQN

# Visualization utils
%load_ext tensorboard
import sys
from tqdm.notebook import trange
!pip install tensorboardx gym pyvirtualdisplay
!apt-get install -y xvfb ffmpeg
!git clone https://github.com/Farama-Foundation/HighwayEnv.git 2> /dev/null
sys.path.insert(0, '/kaggle/working/HighwayEnv/scripts/')
from utils import record_videos, show_videos

In [None]:
import torch
from transformers import BitsAndBytesConfig

# Quantization settings handed to the model loader: 4-bit loading with the "nf4"
# quantization type, double quantization enabled, and bfloat16 as the compute dtype.
config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Seed torch's global RNG before anything that samples from it.
torch.random.manual_seed(0)

# Single source of truth for the checkpoint so the model and the tokenizer can never
# be loaded from two different repositories by accident.
MODEL_ID = "microsoft/Phi-3-mini-4k-instruct"

llm_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="cuda",
    torch_dtype="auto",
    # NOTE(review): trust_remote_code=True executes Python code shipped in the model
    # repository. Confirm it is still required for this checkpoint before keeping it.
    trust_remote_code=True,
    quantization_config=config,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

In [None]:
def llm_action(prompt1, assist1, prompt2, last_act='FASTER'):
    """Ask the LLM for a driving action and return the text of its final decision.

    The three strings form a user / assistant / user conversation that is rendered
    with the tokenizer's chat template and sent to the module-level ``llm_model``.

    Args:
        prompt1: First user message (task description and list of actions).
        assist1: Assistant reply that acknowledges the task.
        prompt2: Second user message (the concrete traffic scenario).
        last_act: Value returned when the reply contains no ``'Final decision: '``
            marker.

    Returns:
        The text following the first ``'Final decision: '`` marker, stripped and cut
        at the first single quote (and at the next marker, if the model repeats it),
        or ``last_act`` when the marker is missing.
    """
    messages = [
        {"role": "user", "content": prompt1},
        {"role": "assistant", "content": assist1},
        {"role": "user", "content": prompt2},
    ]

    # add_generation_prompt=True ends the rendered prompt with the assistant-turn
    # header, so the model writes a new assistant reply instead of continuing from a
    # conversation that the template has already closed.
    model_inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(llm_model.device)

    output = llm_model.generate(model_inputs, max_new_tokens=2000, do_sample=True)

    # Drop the prompt tokens; only the newly generated continuation is decoded.
    decoded_output = tokenizer.batch_decode(
        output[:, model_inputs.size(1):], skip_special_tokens=True
    )

    # An absent marker is the only expected failure; real errors (CUDA, OOM, ...)
    # are no longer swallowed by a bare except.
    parts = decoded_output[0].strip().split('Final decision: ')
    if len(parts) < 2:
        return last_act

    return parts[1].strip().split('\'')[0]

In [None]:
from stable_baselines3 import DQN
import pprint
from matplotlib import pyplot as plt
import numpy as np

In [None]:
# Discrete action index -> action name; the tuple is listed in index order (0..4).
action_dict = dict(enumerate(
    ('LANE_LEFT', 'IDLE', 'LANE_RIGHT', 'FASTER', 'SLOWER')
))

In [None]:
class MyHighwayEnv(gym.Env):
    """highway-env wrapper whose reward is shaped by an LLM driving advisor.

    Every step forwards the action to the wrapped ``highway-fast-v0`` environment.
    With probability ``llm_query_prob`` the LLM is asked which action it would take
    in the state the agent acted from. If its reply names exactly one action, the
    native reward is blended with a 0/1 agreement bonus. The result is squashed
    with a sigmoid before being returned.

    Args:
        vehicleCount: Number of vehicles (ego included) in the Kinematics observation.
        llm_query_prob: Probability of consulting the LLM on a given step.
        llm_weight: Weight of the agreement bonus; the native reward gets
            ``1 - llm_weight``.
    """

    # Lane centres of highway-env's straight road are LANE_WIDTH apart, starting at y = 0.
    LANE_WIDTH = 4.0
    # Must stay at 4: the prompt text below describes exactly four lanes.
    LANES_COUNT = 4

    def __init__(self, vehicleCount=10, llm_query_prob=0.33, llm_weight=0.3):
        super().__init__()
        # base setting
        self.vehicleCount = vehicleCount
        self.llm_query_prob = llm_query_prob
        self.llm_weight = llm_weight
        self.prev_action = 'FASTER'
        # Observation the agent will act on next; filled by reset() and step().
        self._last_obs = None

        # environment setting
        self.config = {
            "observation": {
                "type": "Kinematics",
                "features": ["presence", "x", "y", "vx", "vy"],
                "absolute": True,
                "normalize": False,
                "vehicles_count": vehicleCount,
                "see_behind": True,
            },
            "action": {
                "type": "DiscreteMetaAction",
                "target_speeds": np.linspace(0, 32, 9),
            },
            # highway-fast-v0 defaults to 3 lanes; the prompt talks about 4.
            "lanes_count": self.LANES_COUNT,
            "duration": 40,
            "vehicles_density": 2,
            "show_trajectories": True,
            "render_agent": True,
        }
        self.env = gym.make("highway-fast-v0")
        # configure() lives on the base environment, not on the wrappers that
        # gym.make() adds; the new settings take effect at the next reset().
        self.env.unwrapped.configure(self.config)
        self.action_space = self.env.action_space
        # One row per observed vehicle, one column per configured feature.
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(vehicleCount, 5), dtype=np.float32
        )

    def find_smallest_positive(self, arr):
        """Return ``(value, index)`` of the smallest strictly positive entry of ``arr``.

        Returns ``(inf, -1)`` when ``arr`` has no positive entry.
        """
        smallest_positive = float('inf')
        index = -1

        for i, value in enumerate(arr):
            if 0 < value < smallest_positive:
                smallest_positive = value
                index = i

        return smallest_positive, index

    def prompt_design(self, obs_):
        """Build the (user, assistant, user) prompt triple that describes ``obs_``.

        Args:
            obs_: Kinematics observation with columns presence, x, y, vx, vy;
                row 0 is the ego vehicle.

        Returns:
            Tuple ``(prompt1, assist1, prompt2)`` of strings for ``llm_action``.
        """

        prompt1 = 'You are a smart driving assistant. You, the \'ego\' car, are now driving on a highway. You need to recommend ONLY ONE best action among the following set of actions based on the current scenario: \n \
        \t1. IDLE -- maintain the current speed in the current lane \n \
        \t2. FASTER -- accelerate the ego vehicle \n \
        \t3. SLOWER -- decelerate the ego vehicle \n \
        \t4. LANE_LEFT -- change to the adjacent left lane \n \
        \t5. LANE_RIGHT -- change to the adjacent right lane\n'

        assist1 = 'Understood. Please provide the current scenario or conditions, such as traffic density, speed of surrounding vehicles, your current speed, and any other relevant information, so I can recommend the best action.'

        prompt2 = 'Here is the current scenario:\n \
        There are four lanes on the highway: Lane-1 (left most), Lane-2, Lane-3, Lane-4 (right most). \n\n'

        presence = obs_[:, 0]
        x, y, vx = obs_[:, 1], obs_[:, 2], obs_[:, 3]

        ego_x  = x[0]
        ego_vx = vx[0]

        # Longitudinal offset of every other vehicle relative to the ego vehicle.
        veh_x  = x[1:] - ego_x
        veh_vx = vx[1:]

        # Nearest lane centre, as a 1-based integer lane number. Rounding (not floor
        # division) keeps a vehicle in its lane while it drifts off-centre, and the
        # integer dtype makes the prompt read "Lane-2" instead of "Lane-2.0".
        lanes     = np.clip(np.rint(y / self.LANE_WIDTH), 0, self.LANES_COUNT - 1).astype(int) + 1
        ego_lane  = lanes[0]
        veh_lanes = lanes[1:]

        # Only vehicles that really exist (presence flag set; padded rows are zeros)
        # and are in front of the ego vehicle count as "ahead".
        ahead = (presence[1:] > 0) & (veh_x > 0)

        if ego_lane == 1:
            ego_left_lane  = 'Left lane: Not available\n'
            ego_right_lane = 'Right lane: Lane-' + str(ego_lane+1) + '\n'
        elif ego_lane == self.LANES_COUNT:
            ego_left_lane  = 'Left lane: Lane-' + str(ego_lane-1) + '\n'
            ego_right_lane = 'Right lane: Not available\n'
        else:
            ego_left_lane  = 'Left lane: Lane-' + str(ego_lane-1) + '\n'
            ego_right_lane = 'Right lane: Lane-' + str(ego_lane+1) + '\n'

        prompt2 += 'Ego vehicle:\n \
        \tCurrent lane: Lane-' + str(ego_lane) + '\n' + '\t' + ego_left_lane + '\t' + ego_right_lane + '\tCurrent speed: ' + str(ego_vx) + ' m/s \n\n'

        lane_info = 'Lane info:\n'
        for i in range(self.LANES_COUNT):
            inds  = np.where((veh_lanes == i+1) & ahead)[0]
            num_v = len(inds)
            if num_v > 0:
                # Every candidate is ahead, so a closest one always exists here.
                _, ind   = self.find_smallest_positive(veh_x[inds])
                true_ind = inds[ind]
                lane_info += '\tLane-' + str(i+1) + ': There are ' + str(num_v) + ' vehicle(s) in this lane ahead of ego vehicle, closest being ' + str(veh_x[true_ind]) + ' m ahead traveling at ' + str(veh_vx[true_ind]) + ' m/s. \n'
            else:
                lane_info += '\tLane-' + str(i+1) + ': No other vehicle ahead of ego vehicle.\n'

        prompt2 += lane_info

        att_info = '\nAttention points:\n \
        \t1. SLOWER has least priority and should be used only when no other action is safe.\n \
        \t2. DO NOT change lanes frequently.\n \
        \t3. Safety is priority, but do not forget efficiency.\n \
        \t4. Your suggested action has to be one from one of the above five listed actions - IDLE, SLOWER, FASTER, LANE_LEFT, LANE_RIGHT. \n \
        Your last action was ' + self.prev_action + '. Please recommend action for the current scenario ONLY in the format \'Final decision: <final decision>\'.\n'

        prompt2 += att_info

        return prompt1, assist1, prompt2

    def step(self, action):
        """Step the wrapped environment and return the LLM-shaped reward.

        Returns:
            ``(obs, reward, done, truncated, info)`` where ``reward`` is the sigmoid
            of the (possibly blended) native reward.
        """
        # The LLM must judge the state in which `action` was chosen, i.e. the
        # observation from before this step (the new one is used only if reset()
        # was never called).
        decision_obs = self._last_obs

        # Step the wrapped environment and capture all returned values
        obs, dqn_reward, done, truncated, info = self.env.step(action)
        if decision_obs is None:
            decision_obs = obs

        # Plain int so that comparisons and dict lookups also work for numpy scalars.
        action_idx  = int(action)
        comb_reward = dqn_reward

        if np.random.rand() <= self.llm_query_prob:
            prompt1, assist1, prompt2 = self.prompt_design(decision_obs)

            action_llm = llm_action(prompt1, assist1, prompt2, self.prev_action).strip().split('.')[0]

            # Indices of all action names mentioned in the reply; the reply only
            # counts when it names exactly one action.
            mentioned = [idx for idx, name in action_dict.items() if name in action_llm]
            if len(mentioned) == 1:
                llm_reward  = 1 if mentioned[0] == action_idx else 0
                comb_reward = (1 - self.llm_weight) * dqn_reward + self.llm_weight * llm_reward

        self.prev_action = action_dict[action_idx]
        self._last_obs   = obs

        Reward = float(1 / (1 + np.exp(-comb_reward)))

        return obs, Reward, done, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped environment and the per-episode bookkeeping.

        Returns:
            The ``(obs, info)`` pair produced by the wrapped environment.
        """
        obs, info = self.env.reset(**kwargs)
        self._last_obs   = obs
        # A new episode starts without history, same as a freshly built wrapper.
        self.prev_action = 'FASTER'
        return obs, info

    def close(self):
        """Release the resources held by the wrapped environment."""
        self.env.close()

In [None]:
from stable_baselines3.common.vec_env import DummyVecEnv
# Training environment: the LLM-shaped wrapper with its default of 10 observed vehicles.
env = MyHighwayEnv(vehicleCount=10)

In [None]:
# DQN agent with a two-layer MLP Q-network, trained on the LLM-shaped environment.
# A fixed seed is passed so that repeated runs start from the same initial weights
# and exploration sequence.
model = DQN(
    'MlpPolicy',
    env,
    policy_kwargs=dict(net_arch=[256, 256]),
    learning_rate=5e-4,
    buffer_size=15000,
    learning_starts=200,
    batch_size=32,
    gamma=0.8,
    train_freq=1,
    gradient_steps=1,
    target_update_interval=50,
    exploration_fraction=0.7,
    seed=0,
    verbose=1,
)

model.learn(total_timesteps=int(2e4))

In [None]:
# Persist the trained agent so it can be reloaded without retraining.
model.save(path='models/llm_model')

In [None]:
from gymnasium.wrappers import RecordVideo
# base setting
vehicleCount = 10

# environment setting
# The observation block must match what the agent was trained on (see
# MyHighwayEnv.config): in particular "normalize" stays False, otherwise the policy
# is fed inputs on a different scale than during training.
config = {
    "observation": {
        "type": "Kinematics",
        "features": ["presence", "x", "y", "vx", "vy"],
        "absolute": True,
        "normalize": False,
        "vehicles_count": vehicleCount,
        "see_behind": True,
    },
    "action": {
        "type": "DiscreteMetaAction",
        "target_speeds": np.linspace(0, 32, 9),
    },
    "duration": 40,
    "vehicles_density": 2,
    "show_trajectories": True,
    "render_agent": True,
}


env = gym.make('highway-v0', render_mode='rgb_array')
# configure() is defined on the base environment, not on the wrappers added by
# gym.make(); the settings are applied by the reset() at the start of each episode.
env.unwrapped.configure(config)
env = record_videos(env)
try:
    for episode in trange(3, desc='Test episodes'):
        (obs, info), done, truncated = env.reset(), False, False
        while not (done or truncated):
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, truncated, info = env.step(int(action))
finally:
    # Close the environment even if an episode fails part-way through.
    env.close()
show_videos()