In [1]:
import bisect
import os

import rosbag2_py
import numpy as np
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image, LaserScan
from nav_msgs.msg import Odometry
from geometry_msgs.msg import Pose
from rclpy.serialization import deserialize_message
from rclpy.qos import QoSProfile, QoSDurabilityPolicy
from cv_bridge import CvBridge
import cv2

class ImitationLearningDataset():
    """Extract, time-align and filter camera / laser / odometry data from a ROS 2 bag.

    The whole bag is read in the constructor. Camera images are converted to BGR
    OpenCV arrays. Laser scans and odometry are kept both as extracted values and
    as the raw serialized message. ``/tf`` and ``/tf_static`` are kept serialized.
    Every buffered entry carries the bag timestamp (nanoseconds) as its last element.

    Typical use: ``save_dataset()``, or the individual steps ``align_data`` ->
    ``calculate_goal_position`` -> ``remove_stopped_data`` -> ``save_filtered_bag``.

    Args:
        bag_file: URI of the input rosbag2 (sqlite3 storage, CDR serialization).
        output_dir: Directory in which ``filtered_data.bag`` is created.
        stop_threshold: Planar speed (from ``twist.twist.linear``; presumably m/s)
            at or below which a sample counts as stationary.
        stop_duration_limit: Stationary stretches longer than this many seconds are
            removed by ``remove_stopped_data``.
    """
    def __init__(self, bag_file, output_dir, stop_threshold=0.0001, stop_duration_limit=6.0):
        self.bag_file = bag_file
        self.output_dir = output_dir
        self.bridge = CvBridge()
        self.stop_threshold = stop_threshold
        self.stop_duration_limit = stop_duration_limit

        # Containers for data; every entry ends with the bag timestamp.
        self.images = []              # (bgr image, timestamp)
        self.lasers = []              # (ranges array, serialized LaserScan, timestamp)
        self.odoms = []               # ((x, y, orientation.z), serialized Odometry, timestamp)
        self.velocities = []          # (planar speed, Twist, timestamp)
        self.tf_messages = []         # (serialized TFMessage, timestamp)
        self.tf_static_messages = []  # (serialized TFMessage, timestamp)
        self.timestamps = []          # image timestamps only

        # Open the bag file
        self.storage_options, self.converter_options = self.setup_rosbag_options(bag_file)
        self.reader = rosbag2_py.SequentialReader()
        self.reader.open(self.storage_options, self.converter_options)

        # Process messages (the entire bag is read eagerly here)
        self.process_bag()

    def setup_rosbag_options(self, bag_file):
        """Return ``(StorageOptions, ConverterOptions)`` for reading a sqlite3 / CDR bag."""
        storage_options = rosbag2_py.StorageOptions(uri=bag_file, storage_id='sqlite3')
        converter_options = rosbag2_py.ConverterOptions(input_serialization_format='cdr', output_serialization_format='cdr')
        return storage_options, converter_options

    def process_bag(self):
        """Read every message and route the topics of interest to their handler.

        Messages on any other topic are ignored.
        """
        handlers = {
            "/camera/color/image_raw": self.process_image,
            "/scan": self.process_laser,
            "/odom": self.process_odom,
            "/tf": self.process_tf,
            "/tf_static": self.process_tf_static,
        }
        while self.reader.has_next():
            topic, msg, t = self.reader.read_next()
            handler = handlers.get(topic)
            if handler is not None:
                handler(msg, t)

    def process_image(self, msg, timestamp):
        """Deserialize an Image, convert it to a BGR OpenCV array and buffer it."""
        msg = deserialize_message(msg, Image)
        cv_image = self.bridge.imgmsg_to_cv2(msg, "bgr8")
        self.images.append((cv_image, timestamp))
        self.timestamps.append(timestamp)

    def process_laser(self, msg, timestamp):
        """Buffer the scan ranges (as a numpy array) together with the raw serialized message."""
        msg_ = deserialize_message(msg, LaserScan)
        laser_data = np.array(msg_.ranges)
        self.lasers.append((laser_data, msg, timestamp))

    def process_odom(self, msg, timestamp):
        """Buffer the pose summary, the raw serialized message and the planar speed.

        NOTE: ``orientation.z`` is the z component of the orientation quaternion,
        not a yaw angle; it is stored as-is.
        """
        msg_ = deserialize_message(msg, Odometry)
        odom_data = (msg_.pose.pose.position.x, msg_.pose.pose.position.y, msg_.pose.pose.orientation.z)
        velocity = np.sqrt(msg_.twist.twist.linear.x**2 + msg_.twist.twist.linear.y**2)
        self.odoms.append((odom_data, msg, timestamp))
        self.velocities.append((velocity, msg_.twist.twist, timestamp))

    def process_tf(self, msg, timestamp):
        """Keep a serialized ``/tf`` message so it can be copied verbatim to the output bag."""
        self.tf_messages.append((msg, timestamp))

    def process_tf_static(self, msg, timestamp):
        """Keep a serialized ``/tf_static`` message so it can be copied verbatim to the output bag."""
        self.tf_static_messages.append((msg, timestamp))

    @staticmethod
    def _nearest_by_timestamp(entries, stamps, target):
        """Return the entry of ``entries`` whose timestamp is closest to ``target``.

        ``entries`` must be sorted by timestamp, ``stamps`` must hold the matching
        timestamps and both must be non-empty. On a tie the earlier entry wins,
        which is what ``min`` returns over a time-ordered list.
        """
        after = bisect.bisect_left(stamps, target)  # first stamp >= target
        if after == 0:
            return entries[0]
        # First index of the latest stamp < target; stepping back over duplicate
        # stamps keeps the earliest-recorded message, as ``min`` would.
        before = bisect.bisect_left(stamps, stamps[after - 1])
        if after == len(stamps) or target - stamps[before] <= stamps[after] - target:
            return entries[before]
        return entries[after]

    def align_data(self):
        """Pair every image with the nearest laser, odometry and velocity sample.

        Returns:
            A list of rows ``[img, laser, odom, velocity, img_timestamp]`` where
            ``laser`` / ``odom`` / ``velocity`` are the buffered tuples closest in
            time to the image.

        Raises:
            ValueError: if images exist but one of the other streams is empty.
        """
        print('aligning data')

        sorted_streams = []
        for name, entries in (('/scan', self.lasers), ('/odom', self.odoms), ('/odom velocity', self.velocities)):
            if self.images and not entries:
                raise ValueError(f"cannot align images: no {name} samples were read from the bag")
            # Stable sort by timestamp so the nearest sample can be found by binary
            # search instead of scanning the whole list for every image.
            ordered = sorted(entries, key=lambda entry: entry[-1])
            sorted_streams.append((ordered, [entry[-1] for entry in ordered]))
        (lasers, laser_stamps), (odoms, odom_stamps), (velocities, velocity_stamps) = sorted_streams

        aligned_data = []
        for img, img_timestamp in self.images:
            closest_laser = self._nearest_by_timestamp(lasers, laser_stamps, img_timestamp)
            closest_odom = self._nearest_by_timestamp(odoms, odom_stamps, img_timestamp)
            closest_vel = self._nearest_by_timestamp(velocities, velocity_stamps, img_timestamp)
            aligned_data.append([img, closest_laser, closest_odom, closest_vel, img_timestamp])

        return aligned_data

    def remove_stopped_data(self):
        """Drop stationary stretches longer than ``stop_duration_limit`` seconds.

        Rows of ``self.aligned_data`` with speed above ``stop_threshold`` are kept.
        Stationary rows are buffered. When motion resumes, the buffer is kept if
        the stop lasted at most ``stop_duration_limit`` seconds and discarded
        otherwise. Stationary rows at the very end of the recording are never
        flushed and are therefore dropped. The result, in original order, is
        stored in ``self.filtered_aligned_data``.
        """
        limit_ns = self.stop_duration_limit * 1e9  # row timestamps are nanoseconds
        moving_indices = []
        pending_stop_indices = []  # stationary rows of the stop currently in progress
        stopped_time = 0

        for i, row in enumerate(self.aligned_data):
            velocity, timestamp = row[3], row[4]
            if velocity[0] > self.stop_threshold:
                if stopped_time <= limit_ns:
                    moving_indices.extend(pending_stop_indices)
                # Reset on every resume so a discarded long stop cannot leak into a
                # later short stop and be re-added out of order.
                pending_stop_indices = []
                moving_indices.append(i)
                stopped_time = 0
            else:
                if i > 0:
                    stopped_time += timestamp - self.aligned_data[i - 1][4]
                if stopped_time <= limit_ns:
                    pending_stop_indices.append(i)

        # Filter data by moving indices
        self.filtered_aligned_data = [self.aligned_data[i] for i in moving_indices]

    def save_dataset(self):
        """Run the full pipeline: align, attach goals, drop long stops, write the bag.

        Only the rows that survive ``remove_stopped_data`` are written.
        """
        self.aligned_data = self.align_data()
        # Determine the goal position for every aligned row.
        self.goal = self.calculate_goal_position(self.aligned_data)
        self.remove_stopped_data()
        self.save_filtered_bag(self.filtered_aligned_data)

    def calculate_goal_position(self, data, target_distance=1.0):
        """Attach a goal odometry sample to every aligned row.

        For row ``i`` the goal is the odometry entry of the first row ``j >= i``
        whose (x, y) position is at least ``target_distance`` away. If no such row
        exists, the last row's odometry entry is used. The goal is stored as the
        sixth element of each row, replacing any previous goal, so repeated calls
        are safe.

        Args:
            data: Rows as produced by ``align_data`` (mutated in place).
            target_distance: Minimum planar distance to the goal, in odometry
                position units (1.0 by default).

        Returns:
            The list of goal odometry entries, one per row.
        """
        goals = []
        for i, row in enumerate(data):
            start_pose = row[2][0]  # (x, y, orientation.z)
            goal_odom = data[-1][2]  # fallback when the target distance is never reached
            for j in range(i, len(data)):
                candidate = data[j][2]
                distance = np.sqrt((candidate[0][0] - start_pose[0]) ** 2 + (candidate[0][1] - start_pose[1]) ** 2)
                if distance >= target_distance:
                    goal_odom = candidate
                    break
            row[5:] = [goal_odom]
            goals.append(goal_odom)
        return goals

    def save_filtered_bag(self, data):
        """Write aligned rows plus all buffered tf / tf_static messages to a new bag.

        The bag is created at ``<output_dir>/filtered_data.bag`` (sqlite3, CDR).
        For every row, the image, scan, odometry and goal odometry are written with
        the image timestamp, used both as bag time and as ``header.stamp``. tf and
        tf_static messages are written unfiltered with their original timestamps.

        NOTE(review): rosbag2 generally refuses to create an output URI that already
        exists; remove a previous ``filtered_data.bag`` before re-running.

        Args:
            data: Rows ``[img, laser, odom, velocity, timestamp, goal_odom]`` as
                produced by ``align_data`` + ``calculate_goal_position``.
        """
        print('Saving as bag file')
        # Create a new bag file for the filtered data
        output_bag_file = os.path.join(self.output_dir, "filtered_data.bag")
        writer = rosbag2_py.SequentialWriter()

        # Set up storage and converter options for writing
        storage_options = rosbag2_py.StorageOptions(uri=output_bag_file, storage_id='sqlite3')
        converter_options = rosbag2_py.ConverterOptions(input_serialization_format='cdr', output_serialization_format='cdr')
        writer.open(storage_options, converter_options)

        # Three identical transient-local QoS profiles (durability: 1) so tf_static
        # consumers that replay this bag still receive the latched transforms.
        qos_profile_tf_static = "- history: 3\n  depth: 0\n  reliability: 1\n  durability: 1\n  deadline:\n    sec: 9223372036\n    nsec: 854775807\n  lifespan:\n    sec: 9223372036\n    nsec: 854775807\n  liveliness: 1\n  liveliness_lease_duration:\n    sec: 9223372036\n    nsec: 854775807\n  avoid_ros_namespace_conventions: false\n- history: 3\n  depth: 0\n  reliability: 1\n  durability: 1\n  deadline:\n    sec: 9223372036\n    nsec: 854775807\n  lifespan:\n    sec: 9223372036\n    nsec: 854775807\n  liveliness: 1\n  liveliness_lease_duration:\n    sec: 9223372036\n    nsec: 854775807\n  avoid_ros_namespace_conventions: false\n- history: 3\n  depth: 0\n  reliability: 1\n  durability: 1\n  deadline:\n    sec: 9223372036\n    nsec: 854775807\n  lifespan:\n    sec: 9223372036\n    nsec: 854775807\n  liveliness: 1\n  liveliness_lease_duration:\n    sec: 9223372036\n    nsec: 854775807\n  avoid_ros_namespace_conventions: false"

        # Create topics
        writer.create_topic(rosbag2_py.TopicMetadata(name='/camera/color/image_raw', type='sensor_msgs/msg/Image', serialization_format='cdr'))
        writer.create_topic(rosbag2_py.TopicMetadata(name='/scan', type='sensor_msgs/msg/LaserScan', serialization_format='cdr'))
        writer.create_topic(rosbag2_py.TopicMetadata(name='/odom', type='nav_msgs/msg/Odometry', serialization_format='cdr'))
        writer.create_topic(rosbag2_py.TopicMetadata(name='/goal_odom', type='nav_msgs/msg/Odometry', serialization_format='cdr'))
        writer.create_topic(rosbag2_py.TopicMetadata(name='/tf', type='tf2_msgs/msg/TFMessage', serialization_format='cdr'))
        writer.create_topic(rosbag2_py.TopicMetadata(name='/tf_static', type='tf2_msgs/msg/TFMessage', serialization_format='cdr', offered_qos_profiles=qos_profile_tf_static))

        # Write aligned data back to the bag
        for img, laser, odom, velocity, timestamp, goal_odom in data:
            # Build the stamp from integer nanoseconds; going through float seconds
            # (timestamp * 1e-9) loses sub-microsecond precision at epoch scale.
            stamp = rclpy.time.Time(nanoseconds=timestamp).to_msg()

            # Convert the image back to a ROS Image message
            img_msg = self.bridge.cv2_to_imgmsg(img, "bgr8")
            img_msg.header.stamp = stamp

            # Re-stamp the original LaserScan message
            laser_msg = deserialize_message(laser[1], LaserScan)
            laser_msg.header.stamp = stamp

            # Re-stamp the original Odometry message
            odom_msg = deserialize_message(odom[1], Odometry)
            odom_msg.header.stamp = stamp

            # Goal odometry is published with the current frame's stamp
            goal_odom_msg = deserialize_message(goal_odom[1], Odometry)
            goal_odom_msg.header.stamp = stamp

            # Write messages to the new bag file
            serialized_img = rclpy.serialization.serialize_message(img_msg)
            writer.write('/camera/color/image_raw', serialized_img, timestamp)
            serialized_laser = rclpy.serialization.serialize_message(laser_msg)
            writer.write('/scan', serialized_laser, timestamp)
            serialized_odom = rclpy.serialization.serialize_message(odom_msg)
            writer.write('/odom', serialized_odom, timestamp)
            serialized_goal_odom = rclpy.serialization.serialize_message(goal_odom_msg)
            writer.write('/goal_odom', serialized_goal_odom, timestamp)

        # Write tf and tf_static messages (still serialized) to the new bag file
        for tf_msg, tf_timestamp in self.tf_messages:
            writer.write('/tf', tf_msg, tf_timestamp)
        for tf_static_msg, tf_static_timestamp in self.tf_static_messages:
            writer.write('/tf_static', tf_static_msg, tf_static_timestamp)






In [2]:
bag_file = "/home/nigitha/corr2"  # Replace with your actual bag file path
output_dir = "/home/nigitha/ros2_ws_rnd/src/dataset"  # Replace with your output directory path

# exist_ok=True replaces the separate existence check (no check-then-create race).
os.makedirs(output_dir, exist_ok=True)

# The constructor opens the bag and reads every message into memory.
dataset_processor = ImitationLearningDataset(bag_file, output_dir)

[INFO] [1725613798.708316265] [rosbag2_storage]: Opened database '/home/nigitha/corr2/corr2_0.db3' for READ_ONLY.


In [3]:
# Pair each camera frame with its nearest scan / odometry sample, attach goal
# poses, then drop long standstills (result: `filtered_aligned_data`).
aligned_rows = dataset_processor.align_data()
dataset_processor.aligned_data = aligned_rows
dataset_processor.calculate_goal_position(aligned_rows)
dataset_processor.remove_stopped_data()

aligning data


In [5]:
# Persist only the rows that survived `remove_stopped_data` to a new rosbag2 directory.
rows_to_save = dataset_processor.filtered_aligned_data
dataset_processor.save_filtered_bag(rows_to_save)

Saving as bag file


[INFO] [1725615555.176338847] [rosbag2_storage]: Opened database '/home/nigitha/ros2_ws_rnd/src/dataset/filtered_data.bag/filtered_data.bag_0.db3' for READ_WRITE.
