# SJTU Small Traffic Light Dataset (S2TLD) Extraction
Script to extract square images of traffic lights from the SJTU Small Traffic Light Dataset (S2TLD). The dataset can be found here:

https://github.com/Thinklab-SJTU/S2TLD

In [None]:
import pandas as pd
import os
from PIL import Image
from tqdm import tqdm
import re
import xml.etree.ElementTree as ET
import re

## Constants

In [None]:
# DATA_DIR must already exist and contain the 'Annotations' and 'JPEGImages'
# sub-directories. The output directories are created by the extraction step
# (os.makedirs), so OUTPUT_DATA_DIR does not need to exist beforehand.

# If True, crops are saved into one sub-directory per annotation tag;
# otherwise they all go into a single "all" sub-directory.
split_classes = True

# Windows paths are written as raw strings: in a plain literal, sequences such
# as '\D' and '\S' are invalid escapes (a SyntaxWarning on Python 3.12+).
# DATA_DIR = r'D:\Data\Datasets\S2TLD\1080p'
# OUTPUT_DATA_DIR = r'D:\Data\Datasets\S2TLD_extracted'

DATA_DIR = r'D:\Data\Datasets\S2TLD\720p'
OUTPUT_DATA_DIR = r'D:\Data\Datasets\S2TLD_extracted'

## Load Dataset

In [None]:
# Function to parse XML and extract data
def parse_annotation(xml_file):
    tree = ET.parse(xml_file)
    root = tree.getroot()

    # Extract image properties
    filename = root.find('filename').text
    folder = root.find('folder').text if root.find('folder') is not None else None
    database = root.find('database').text if root.find('database') is not None else None
    annotation_info = root.find('annotation').text if root.find('annotation') is not None else None
    image_info = root.find('image').text if root.find('image') is not None else None
    segmented = root.find('segmented').text if root.find('segmented') is not None else None

    size_element = root.find('size')
    size = {
        'height': int(size_element.find('height').text),
        'width': int(size_element.find('width').text),
        'depth': int(size_element.find('depth').text)
    } if size_element is not None else None

    # List to hold all objects' data
    objects_data = []

    # Iterate through all objects in the XML
    for obj in root.findall('object'):
        name = obj.find('name').text
        bndbox = obj.find('bndbox')
        xmin = int(bndbox.find('xmin').text)
        ymin = int(bndbox.find('ymin').text)
        xmax = int(bndbox.find('xmax').text)
        ymax = int(bndbox.find('ymax').text)

        object_data = [filename, folder, database, annotation_info, image_info, size['height'], size['width'], size['depth'], segmented, name, xmin, ymin, xmax, ymax]
        objects_data.append(object_data)

    return objects_data

# Parse every annotation XML under DATA_DIR/Annotations into one flat list of
# per-object rows. os.listdir returns entries in arbitrary order, so the
# listing is sorted to make the row order of `df` reproducible.
annotations_path = os.path.join(DATA_DIR, 'Annotations')
all_objects = []

for xml_file in tqdm(sorted(os.listdir(annotations_path)), desc="Processing XML files"):
    if xml_file.endswith('.xml'):
        xml_path = os.path.join(annotations_path, xml_file)
        all_objects.extend(parse_annotation(xml_path))

# Create DataFrame with all fields: one row per annotated object, with the
# columns in the same order as the rows built by parse_annotation.
columns = ['Filename', 'Folder', 'Database', 'Annotation Info', 'Image Info', 'Image Height', 'Image Width', 'Image Depth', 'Segmented', 'Annotation tag', 'Upper left corner X', 'Upper left corner Y', 'Lower right corner X', 'Lower right corner Y']
df = pd.DataFrame(all_objects, columns=columns)

# Bounding-box (traffic light) width and height, in the same units as the
# corner coordinates. These are distinct from the 'Image Width' and
# 'Image Height' columns.
df['width'] = df['Lower right corner X'] - df['Upper left corner X']
df['height'] = df['Lower right corner Y'] - df['Upper left corner Y']

In [None]:
# Display the first few rows of the DataFrame (one row per annotated object)
df.head()

In [None]:
# Summary statistics of the numeric columns (image size, box corners, and the
# derived box width/height)
df.describe()

## Extract Images

In [None]:
# Parameters
image_size = 80  # Side length of every saved square crop
height_thresh = 40  # Minimum bounding-box height for a traffic light to be extracted
# The crop side is `scale` times the box's largest dimension, so a box whose
# largest dimension equals height_thresh is cropped at exactly image_size px.
scale = image_size / height_thresh

# Annotation <filename> values that match this pattern have the colons in
# their time part replaced with underscores before the image is opened.
# This presumably mirrors how the JPEGImages files are named on disk, since
# colons are not allowed in Windows filenames. NOTE(review): confirm this
# against the dataset.
timestamp_pattern = re.compile(r'(\d{4}-\d{2}-\d{2}) (\d{2}):(\d{2}):(\d{2}\.\d+)')


def _image_basename(annotation_filename):
    """Return the on-disk image basename for an annotation's <filename> value."""
    return timestamp_pattern.sub(r'\1 \2_\3_\4', os.path.basename(annotation_filename))


def _square_box(center_x, center_y, side, img_width, img_height):
    """Return a (left, upper, right, lower) box `side` wide, centred on a point.

    The square is shifted, not truncated, so that it stays inside the image.
    This keeps the crop square and avoids stretching it in the later resize.
    Only when `side` is at least an image dimension is that axis clamped to
    the full image extent.
    """
    def _span(center, limit):
        if side >= limit:
            return 0, limit
        start = min(max(center - side / 2, 0), limit - side)
        return start, start + side

    left, right = _span(center_x, img_width)
    upper, lower = _span(center_y, img_height)
    return left, upper, right, lower


# Create the output directories: one per annotation tag, or a single "all".
classes = df['Annotation tag'].unique()
if split_classes:
    for cls in classes:
        os.makedirs(os.path.join(OUTPUT_DATA_DIR, cls), exist_ok=True)
else:
    os.makedirs(os.path.join(OUTPUT_DATA_DIR, "all"), exist_ok=True)

# Keep only boxes at least height_thresh tall, grouped per annotated image
filtered_df = df[df['height'] >= height_thresh]
grouped = filtered_df.groupby('Filename')

# Print number of instances:
print(f"Found {len(filtered_df)} instances")

box_columns = ['Upper left corner X', 'Upper left corner Y',
               'Lower right corner X', 'Lower right corner Y']

for filepath, group in tqdm(grouped, desc="Processing images"):
    basename = _image_basename(filepath)
    # Prefix crops with the full image stem. For timestamped names, the
    # trailing digits before ".jpg" are only the fractional seconds, so they
    # are not unique per image and would let crops overwrite each other.
    output_stem = os.path.splitext(basename)[0]
    image_path = os.path.join(DATA_DIR, "JPEGImages", basename)

    # .tolist() yields plain Python ints for the box coordinates.
    boxes = group[box_columns].values.tolist()
    tags = group['Annotation tag'].tolist()

    try:
        img = Image.open(image_path)
    except OSError as e:
        # A missing or unreadable image (FileNotFoundError and
        # PIL.UnidentifiedImageError are OSError subclasses): report it and
        # skip this file.
        print(e)
        continue

    # Close the file handle once every crop from this image has been saved.
    with img:
        for i, (cls, (xmin, ymin, xmax, ymax)) in enumerate(zip(tags, boxes)):
            # The square side is based on the box's largest dimension.
            side = max(xmax - xmin, ymax - ymin) * scale
            crop_box = _square_box((xmin + xmax) / 2, (ymin + ymax) / 2,
                                   side, img.width, img.height)

            # Crop and resize to image_size x image_size
            cropped_img = img.crop(crop_box).resize((image_size, image_size))

            subdir = cls if split_classes else "all"
            cropped_img.save(os.path.join(OUTPUT_DATA_DIR, subdir, f"{output_stem}_{i}.jpg"))

## Filtered Dataset Statistics
Statistics describing the dataset after dropping annotations whose traffic lights do not meet the height threshold.

In [None]:
# Number of instances per annotation tag among boxes that passed the height threshold
filtered_df['Annotation tag'].value_counts()