In [3]:
import torch 
import torch.nn as nn
import matplotlib.pyplot as plt
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    Mapping,
    Optional,
    Tuple,
    Type,
    Union,
)
import numpy as np
import pickle
import gymnasium as gym
from gymnasium import spaces
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [5]:
import sys
sys.path.append('../../cleanrl/')
from radar_maps.env.radar_map_double_integrator import RadarMap_DoubleIntegrator

In [6]:
import os

# Root directory of the recorded demonstrations. Defaults to the original
# author's local checkout; set the BC_DATA_PATH environment variable to use a
# different location. os.path.join(..., "") guarantees a trailing separator,
# which code that builds file names by plain string concatenation relies on.
data_path = os.path.join(
    os.environ.get("BC_DATA_PATH", "/home/lucas/workspace/cleanrl/cleanrl/data_multimodal/"),
    "",
)
# Number of recorded episodes to load (file indices 0 .. data_num - 1).
data_num = 1
# Number of modes predicted per time step (controls + path chunk per mode).
num_mode = 3

In [7]:
class Trajectory:
    """Plain container pairing a trajectory's observations with its actions.

    The two sequences are stored as-is: no copying, conversion or validation.
    """

    def __init__(self, obs, actions):
        self.obs, self.actions = obs, actions

        
class TrajDataset(Dataset):
    """Flat dataset of (state, action) pairs pooled from a list of trajectories.

    Each trajectory object must expose ``obs`` and ``actions`` sequences.
    Rows from all trajectories are concatenated along axis 0 so that
    ``states[k]`` and ``actions[k]`` always describe the same time step.

    A trajectory may carry exactly one more observation than actions: the
    observation of its final state, which has no action label
    (``generate_training_data`` returns data in this shape). That trailing
    observation is dropped. Without this, a single trajectory makes the last
    index raise ``IndexError``, and several trajectories pair each state with
    an action from the wrong step.
    """

    def __init__(self, trajs):
        states = []
        actions = []
        for traj in trajs:
            traj_states, traj_actions = self._align(traj.obs, traj.actions)
            if len(traj_actions) == 0:
                # No labelled step (e.g. a one-state trajectory). Skipping it also
                # avoids concatenating an empty 1-D action array with 2-D ones.
                continue
            states.append(traj_states)
            actions.append(traj_actions)
        if not states:
            raise ValueError("TrajDataset needs at least one trajectory with a labelled step")
        self.states = np.concatenate(states, axis=0)
        self.actions = np.concatenate(actions, axis=0)

    @staticmethod
    def _align(states, actions):
        """Return ``(states, actions)`` trimmed to equal length, or raise ValueError."""
        n_states, n_actions = len(states), len(actions)
        if n_states == n_actions:
            return states, actions
        if n_states == n_actions + 1:
            # Drop the terminal observation, which has no matching action.
            return states[:n_actions], actions
        raise ValueError(
            f"trajectory has {n_states} observations but {n_actions} actions; "
            "expected equal counts or exactly one extra (terminal) observation"
        )

    def __len__(self):
        return self.states.shape[0]

    def __getitem__(self, idx):
        """Return ``{'state': ..., 'action': ...}`` for sample ``idx``."""
        return {'state': self.states[idx], 'action': self.actions[idx]}

    def add_traj(self, traj=None, states=None, actions=None):
        """Append one more trajectory, given either as ``traj`` or as raw arrays.

        When ``traj`` is given, ``states``/``actions`` are ignored. Lengths are
        aligned exactly as in ``__init__``. A trajectory with no labelled step
        is ignored.

        Raises:
            ValueError: if neither ``traj`` nor both ``states`` and ``actions``
                are given, or if the lengths cannot be aligned.
        """
        if traj is not None:
            states, actions = traj.obs, traj.actions
        elif states is None or actions is None:
            raise ValueError("add_traj needs `traj`, or both `states` and `actions`")
        states, actions = self._align(states, actions)
        if len(actions) == 0:
            return
        self.states = np.concatenate((self.states, states), axis=0)
        self.actions = np.concatenate((self.actions, actions), axis=0)

In [9]:
def get_path_chunk(path, num_points=3):
    """Flatten the first ``num_points`` waypoints of ``path`` into one list.

    If ``path`` holds fewer than ``num_points`` waypoints, the last waypoint is
    repeated until ``num_points`` waypoints have been emitted, so callers get a
    fixed-length vector.

    Args:
        path: sequence of waypoints, each an iterable of coordinates.
        num_points: number of waypoints to emit.

    Returns:
        A flat Python list of the selected (or padded) waypoint coordinates.
    """
    available = len(path)
    if available < num_points:
        flat = list(np.hstack(path))
        tail = list(path[-1])
        flat.extend(tail * (num_points - available))
        return flat
    return list(np.hstack(path[:num_points]))


def _radar_encoding(state, radar_locs, img_size, radar_detection_range, grid_size):
    """Return the raw (img_size, img_size) sum of Gaussian blobs, one per nearby radar, in the vehicle frame."""
    encoding = np.zeros((img_size, img_size))
    # state[0], state[2] are the position. The frame is rotated by the heading
    # atan2(state[3], state[1]); state[1]/state[3] are presumably velocities.
    theta = np.arctan2(state[3], state[1])
    loc_to_glob = np.array([[np.cos(theta), -np.sin(theta), state[0]],
                            [np.sin(theta), np.cos(theta), state[2]],
                            [0., 0., 1.]])
    glob_to_loc = np.linalg.inv(loc_to_glob)

    # Pixel (i + half, j + half) holds grid offset (i, j) for i, j in [-half, half).
    # For odd img_size the last row/column is never written and stays zero.
    half = int(img_size / 2)
    offsets = np.arange(-half, half)
    for radar_loc in radar_locs:
        # `or` is deliberate: the image window is a square rotated by theta, and
        # every point inside it has |dx| < range or |dy| < range. The union of the
        # two axis bands is therefore a conservative pre-filter. `and` would miss
        # radars near the rotated corners.
        if abs(state[0] - radar_loc[0]) < radar_detection_range or abs(state[2] - radar_loc[1]) < radar_detection_range:
            glob_loc_hom = np.array([radar_loc[0], radar_loc[1], 1])
            local_loc_hom = np.dot(glob_to_loc, glob_loc_hom)
            x_grid = np.rint(local_loc_hom[0] / grid_size)
            y_grid = np.rint(local_loc_hom[1] / grid_size)
            # Broadcast squared distances: rows vary with the x offset, columns with y.
            dist_x_sq = ((x_grid - offsets) ** 2)[:, None]
            dist_y_sq = ((y_grid - offsets) ** 2)[None, :]
            encoding[:2 * half, :2 * half] += np.exp((-dist_x_sq - dist_y_sq) / 2.0) * 1e6
    return encoding


def _render_heat_map(encoding):
    """Colour-map ``encoding`` with 'hot' and round-trip it through an in-memory JPEG."""
    import io

    buffer = io.BytesIO()
    plt.imsave(buffer, encoding.T, cmap='hot', origin='lower', format='jpg')
    buffer.seek(0)
    return plt.imread(buffer, format='jpg')


def get_radar_heat_map(state, radar_locs, img_size, radar_detection_range, grid_size):
    """Render the radars around ``state`` as a heat-map image in the vehicle frame.

    Each radar that passes the range pre-filter contributes a Gaussian blob
    (sigma = 1 grid cell) centred on its position. The position is expressed in
    a frame translated to (state[0], state[2]) and rotated by the heading
    atan2(state[3], state[1]), then discretised by ``grid_size``.

    The map is colour-mapped with 'hot', encoded as JPEG and decoded again. The
    encoding happens in memory, so no file is written to the working directory.
    The round-trip is kept so the pixels match what a JPEG on disk would decode to.

    Args:
        state: vehicle state; indices 0..3 are read (position at 0 and 2,
            heading components at 1 and 3).
        radar_locs: iterable of radar (x, y) positions in the global frame.
        img_size: side length of the square image in grid cells.
        radar_detection_range: half-width of the per-axis pre-filter band.
        grid_size: world units per grid cell.

    Returns:
        The decoded image array as returned by ``plt.imread`` (presumably uint8
        of shape (img_size, img_size, 3) for JPEG; confirm).
    """
    encoding = _radar_encoding(state, radar_locs, img_size, radar_detection_range, grid_size)
    return _render_heat_map(encoding)

def generate_training_data(traj, ctr, path_mm, radars, detection_range, grid_size, v_lim, u_lim,
                           pos_scale=1200.0, n_modes=None):
    """Build behaviour-cloning observations and multimodal action labels for one trajectory.

    Args:
        traj: sequence of states; indices 0..3 of each state are read
            (positions at 0 and 2 are divided by ``pos_scale``, entries 1 and 3 by ``v_lim``).
        ctr: 2-D control array; ``ctr[i, 2*m:2*(m+1)]`` is mode ``m``'s control at step ``i``.
        path_mm: flat sequence of nominal paths, indexed ``n_modes * i + m``. Each
            path is a sequence of waypoints that support division by a scalar
            (presumably numpy arrays; confirm).
        radars, detection_range, grid_size: forwarded to ``get_radar_heat_map``;
            the image side is ``2 * int(detection_range / grid_size)``.
        v_lim: normaliser for state entries 1 and 3.
        u_lim: normaliser for controls.
        pos_scale: normaliser for positions and path waypoints (default 1200.0,
            the value previously hard-coded; presumably the map extent, confirm).
        n_modes: number of modes; defaults to the notebook-level ``num_mode``.

    Returns:
        ``(observations, actions)``. ``observations`` is an object array with one
        ``{"state", "img"}`` dict per state in ``traj``. ``actions`` has one row
        per step except the last (the final state has no control label), so it
        is one row shorter than ``observations``. Each row concatenates, for
        every mode, the normalised control followed by a 3-waypoint path chunk.
    """
    if n_modes is None:
        n_modes = num_mode
    # Depends only on the arguments, so compute it once instead of per state.
    img_size = 2 * int(detection_range / grid_size)
    last_step = len(traj) - 1

    observations = []
    actions = []
    for i, x_cur in enumerate(traj):
        heat_map_img = get_radar_heat_map(x_cur, radars, img_size, detection_range, grid_size)
        x_cur_normalized = np.array([x_cur[0] / pos_scale, x_cur[1] / v_lim,
                                     x_cur[2] / pos_scale, x_cur[3] / v_lim])
        observations.append({"state": x_cur_normalized, "img": heat_map_img})

        if i < last_step:
            action_prediction = []
            for m in range(n_modes):
                action_prediction.extend(ctr[i, 2 * m:2 * (m + 1)] / u_lim)
                path_scaled = [point / pos_scale for point in path_mm[n_modes * i + m]]
                action_prediction.extend(get_path_chunk(path_scaled))
            actions.append(action_prediction)

    return np.array(observations), np.array(actions)

def process_data(detection_range, grid_size, v_lim, u_lim, data_dir=None, n_files=None):
    """Load recorded demonstrations and convert them into a ``TrajDataset``.

    For each index ``i`` in ``range(n_files)`` the following files are read from
    ``data_dir``: ``state_traj_{i}.npy``, ``control_traj_{i}.npy``,
    ``radar_config_{i}.npy`` and ``nominal_path_multimodal_{i}.pkl``. They are
    turned into observations/actions by ``generate_training_data``.

    Args:
        detection_range, grid_size, v_lim, u_lim: forwarded to ``generate_training_data``.
        data_dir: directory holding the files; defaults to the notebook-level
            ``data_path``. A trailing separator is not required.
        n_files: number of episodes to load; defaults to the notebook-level ``data_num``.

    Returns:
        A ``TrajDataset`` built from one ``Trajectory`` per episode.
    """
    import os

    if data_dir is None:
        data_dir = data_path
    if n_files is None:
        n_files = data_num

    bc_data = []
    for i in range(n_files):
        print("Processing data: ", i)
        traj = np.load(os.path.join(data_dir, f'state_traj_{i}.npy'))
        control = np.load(os.path.join(data_dir, f'control_traj_{i}.npy'))
        radar_config = np.load(os.path.join(data_dir, f'radar_config_{i}.npy'))

        # SECURITY: pickle.load can execute arbitrary code; only load files you
        # generated yourself or otherwise trust.
        with open(os.path.join(data_dir, f'nominal_path_multimodal_{i}.pkl'), 'rb') as handle:
            path_mm = pickle.load(handle)

        obs, acts = generate_training_data(traj, control, path_mm, radar_config,
                                           detection_range, grid_size, v_lim, u_lim)
        bc_data.append(Trajectory(obs, acts))
    return TrajDataset(bc_data)