In [1]:
from datasets import load_dataset, Video

# Load only the 'train' split of the "mmnist-easy" variant.
# NOTE(review): `dataset` is rebound in the next cell (In [2]) by another
# load_dataset call for a different variant, so this load looks redundant --
# confirm nothing else relies on it before removing.
dataset = load_dataset("mmnist-dataset/huggingface-arrow-format/mmnist-easy", split='train')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

import numpy as np
import torch
from datasets import load_dataset
from PIL import Image, ImageDraw
from torchvision.ops import box_iou

version = 'medium'

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
NUM_EXAMPLES = 10            # how many training examples to render
ASSETS_DIR = './assets'      # output directory for the generated GIFs
PADDING = 50                 # black margin (px) around the frame on the amodal canvas
IOU_THRESHOLD = 0.1          # boxes with IoU above this are flagged as overlapping
KEYPOINT_RADIUS = 2          # radius (px) of the centre-point marker
KEYPOINT_VISIBLE = 2         # visibility flag value for which the centre point is drawn
GIF_FRAME_DURATION_MS = 200  # delay between GIF frames

# One colour per track id; indexed modulo its length so that large track ids
# cannot run off the end of the palette.
HEX_PALETTE = (
    "042AFF",
    "0BDBEB",
    "F3F3F3",
    "00DFB7",
    "111F68",
    "FF6FDD",
    "FF444F",
    "CCED00",
    "00F344",
    "BD00FF",
    "00B4FF",
    "DD00BA",
    "00FFFF",
    "26C000",
    "01FFB3",
    "7D24FF",
    "7B0068",
    "FF1B6C",
    "FC6D2F",
    "A2FF0B",
    "FFB300",
    "FF00A0",
    "FF00B0",
    "FF00C0",
    "FF00D0",
    "FF00E0",
    "FF00F0",
    "FF00FF",
    "FF00FF",
)

# Rendered variant -> output file name (relative to ASSETS_DIR).
# Insertion order is the order in which the GIFs are written.
GIF_NAME_TEMPLATES = {
    'all': 'annotated_video_{version}_{idx}.gif',
    'cp': 'annotated_video_{version}_cp_{idx}.gif',
    'boxes': 'annotated_video_{version}_boxes_{idx}.gif',
    'boxes_cover': 'annotated_video_{version}_boxes_cover_{idx}.gif',
    'boxes_border': 'annotated_video_{version}_boxes_border_{idx}.gif',
    'amodal': 'annotated_video_{version}_amodal_{idx}.gif',
    'raw': 'video_{version}_{idx}.gif',
}


def track_color(track_id):
    """Return the '#RRGGBB' colour string assigned to ``track_id``."""
    return f'#{HEX_PALETTE[track_id % len(HEX_PALETTE)]}'


def is_missing_box(bbox):
    """Return True when ``bbox`` is the (-1, -1, -1, -1) placeholder.

    The placeholder is treated as "this object has no visible box in this
    frame", so nothing that depends on the visible box is drawn for it.
    """
    return all(value == -1 for value in bbox)


def find_overlapping_boxes(bboxes, iou_threshold=IOU_THRESHOLD):
    """Return the indices of boxes that overlap at least one other box.

    Args:
        bboxes: list of ``[x_min, y_min, width, height]`` boxes (the same
            top-left layout the drawing code uses).
        iou_threshold: a box is reported when its IoU with any *other* box
            is strictly greater than this value.

    Returns:
        A set of integer indices into ``bboxes``. Placeholder boxes
        (see ``is_missing_box``) never take part in the comparison.
    """
    if len(bboxes) == 0:
        return set()

    boxes_xywh = torch.tensor(bboxes, dtype=torch.float32)
    valid_idx = torch.where(~(boxes_xywh == -1).all(dim=1))[0]
    if valid_idx.numel() < 2:
        return set()

    # box_iou expects (x1, y1, x2, y2). The boxes are top-left based, so the
    # bottom-right corner is simply origin + size.
    top_left = boxes_xywh[valid_idx, :2]
    boxes_xyxy = torch.cat((top_left, top_left + boxes_xywh[valid_idx, 2:]), dim=1)

    iou = box_iou(boxes_xyxy, boxes_xyxy)
    iou.fill_diagonal_(0)  # ignore every box's overlap with itself
    has_overlap = (iou > iou_threshold).any(dim=1)
    return set(valid_idx[has_overlap].tolist())


def render_frame(frame, labels, centers, bboxes, amodal_bboxes, track_ids):
    """Render every visualisation variant of a single video frame.

    Args:
        frame: one frame of the video; anything ``np.array`` can convert.
        labels: per-object labels (drawn as text).
        centers: per-object ``(x, y, visibility)`` centre points.
        bboxes: per-object visible boxes as ``[x_min, y_min, w, h]``, or the
            (-1, -1, -1, -1) placeholder.
        amodal_bboxes: per-object full-extent boxes as ``[x_min, y_min, w, h]``;
            they may extend past the frame, hence the padded canvas.
        track_ids: per-object integer track ids (select the colour).

    Returns:
        dict mapping each key of ``GIF_NAME_TEMPLATES`` to a PIL RGB image.
    """
    frame_np = np.array(frame)
    # Ensure the data type is uint8 (assuming values are 0-255)
    if frame_np.dtype != np.uint8:
        frame_np = frame_np.astype(np.uint8)

    base = Image.fromarray(frame_np)
    if base.mode != "RGB":
        base = base.convert("RGB")
    frame_width, frame_height = base.size

    images = {name: base.copy() for name in GIF_NAME_TEMPLATES if name != 'amodal'}

    # The amodal variant lives on a larger black canvas with the frame centred,
    # so boxes reaching outside the frame remain visible.
    amodal_canvas = Image.new("RGB", (frame_width + 2 * PADDING, frame_height + 2 * PADDING))
    amodal_canvas.paste(base, (PADDING, PADDING))
    images['amodal'] = amodal_canvas

    # 'raw' is the un-annotated frame and is never drawn on.
    draws = {name: ImageDraw.Draw(image) for name, image in images.items() if name != 'raw'}

    overlapping = find_overlapping_boxes(bboxes)

    objects = zip(labels, centers, bboxes, amodal_bboxes, track_ids)
    for i, (label, center, bbox, amodal_bbox, track_id) in enumerate(objects):
        color = track_color(track_id)
        text = str(label)
        x, y, is_visible = center

        if is_visible == KEYPOINT_VISIBLE:
            dot = [
                (x - KEYPOINT_RADIUS, y - KEYPOINT_RADIUS),
                (x + KEYPOINT_RADIUS, y + KEYPOINT_RADIUS),
            ]
            draws['all'].ellipse(dot, fill=color)
            draws['cp'].ellipse(dot, fill=color)

        # Amodal box, shifted by the canvas padding and tagged with "A".
        amodal_x, amodal_y, amodal_w, amodal_h = amodal_bbox
        padded_x_min = amodal_x + PADDING
        padded_y_min = amodal_y + PADDING
        padded_x_max = amodal_x + amodal_w + PADDING
        padded_y_max = amodal_y + amodal_h + PADDING
        draws['amodal'].rectangle(
            [padded_x_min, padded_y_min, padded_x_max, padded_y_max],
            outline=color,
            width=1,
        )
        draws['amodal'].text((padded_x_min + 1, padded_y_min + 1), "A", fill=color)

        # Everything below depends on the visible box; skip placeholders so we
        # never draw with undefined or stale corner coordinates.
        if is_missing_box(bbox):
            continue

        x_min, y_min, w, h = bbox
        x_max = x_min + w
        y_max = y_min + h
        box = [x_min, y_min, x_max, y_max]
        label_pos = (x_min + 1, y_min + 1)

        draws['all'].rectangle(box, outline=color, width=1)
        draws['all'].text(label_pos, text, fill=color)

        draws['boxes'].rectangle(box, outline=color, width=1)
        draws['boxes'].text(label_pos, text, fill=color)

        # NOTE(review): the marker above requires visibility == 2 while the
        # label only requires a non-zero flag -- kept as in the original;
        # confirm whether both should use the same condition.
        if is_visible:
            draws['cp'].text(label_pos, text, fill=color)

        if i in overlapping:
            draws['boxes_cover'].rectangle(box, outline=color, width=1)
            draws['boxes_cover'].text(label_pos, text, fill=color)

        touches_border = (
            x_min <= 0
            or y_min <= 0
            or x_max >= frame_width - 1
            or y_max >= frame_height - 1
        )
        if touches_border:
            draws['boxes_border'].rectangle(box, outline=color, width=1)
            draws['boxes_border'].text(label_pos, text, fill=color)

    # Outline the original frame on the amodal canvas. Drawn once, last, so it
    # stays on top and is present even for frames without objects.
    draws['amodal'].rectangle(
        [PADDING, PADDING, PADDING + frame_width - 1, PADDING + frame_height - 1],
        outline="white",
        width=1,
    )

    return images


def save_gif(frames, path):
    """Write ``frames`` (a list of PIL images) to ``path`` as a looping GIF."""
    if not frames:
        return
    frames[0].save(
        path,
        save_all=True,
        append_images=frames[1:],
        duration=GIF_FRAME_DURATION_MS,  # delay between frames (ms)
        loop=0,                          # loop indefinitely
    )


# Load the dataset
dataset = load_dataset(f"mmnist-dataset/huggingface-arrow-format/mmnist-{version}")

# PIL does not create missing directories when saving.
os.makedirs(ASSETS_DIR, exist_ok=True)

train_split = dataset['train']

for idx in range(min(NUM_EXAMPLES, len(train_split))):
    example = train_split[idx]
    video_frames = example['video']

    if len(video_frames) == 0:
        print(f"Skipping example {idx}: video has no frames.")
        continue

    print(f"Processing video with {len(video_frames)} frames...")

    gif_frames = {name: [] for name in GIF_NAME_TEMPLATES}

    # Annotations are indexed per frame; use the actual frame count rather
    # than assuming a fixed video length.
    for frame_idx in range(len(video_frames)):
        rendered = render_frame(
            video_frames[frame_idx],
            labels=example['bboxes_labels'][frame_idx],
            centers=example['bboxes_keypoints'][frame_idx],
            bboxes=example['bboxes'][frame_idx],
            amodal_bboxes=example['amodal_bboxes'][frame_idx],
            track_ids=example['track_ids'][frame_idx],
        )
        for name, image in rendered.items():
            gif_frames[name].append(image)

    print(f"Processed {len(gif_frames['all'])} frames.")

    # Save every variant as a GIF
    for name, template in GIF_NAME_TEMPLATES.items():
        file_name = template.format(version=version, idx=idx)
        save_gif(gif_frames[name], f'{ASSETS_DIR}/{file_name}')

    print(f"GIF {idx} created successfully!")

Generating train split: 15 examples [00:00, 692.71 examples/s]


[[36.0, 112.0, 10.0, 20.0], [45.0, 22.0, 14.0, 20.0], [107.0, 70.0, 20.0, 20.0], [14.0, 19.0, 13.0, 20.0], [82.0, 88.0, 20.0, 18.0]]
[[37.0, 112.0, 9.0, 16.0], [45.0, 22.0, 14.0, 20.0], [107.0, 70.0, 20.0, 20.0], [14.0, 19.0, 13.0, 20.0], [82.0, 88.0, 20.0, 18.0]]
[[37.54721450805664, 112.87699127197266, 10.0, 20.0], [37.0, 31.0, 14.0, 20.0], [102.0, 77.0, 20.0, 20.0], [8.0, 98.0, 14.0, 20.0], [-18.0, 97.1086196899414, 20.0, 20.0], [20.0, 19.0, 13.0, 20.0], [80.0, 81.0, 20.0, 18.0], [53.0, 61.0, 20.0, 20.0], [90.0, 111.0, 20.0, 20.0]]
[[40.0, 113.0, 8.0, 15.0], [37.0, 31.0, 14.0, 20.0], [102.0, 77.0, 20.0, 20.0], [8.0, 98.0, 14.0, 20.0], [0.0, 99.0, 2.0, 11.0], [20.0, 19.0, 13.0, 20.0], [80.0, 81.0, 20.0, 18.0], [53.0, 61.0, 20.0, 20.0], [92.0, 111.0, 20.0, 17.0]]
Processing video with 20 frames...
Processed 20 frames.
GIF 0 created successfully!
[[29.0, 64.0, 14.0, 20.0], [-5.902544975280762, 0.6457198262214661, 7.0, 20.0], [19.0, 12.0, 20.0, 20.0], [21.0, 58.0, 14.0, 20.0]]
[[29.0, 6