In [None]:
# Show the current working directory (IPython automagic for `%pwd`).
pwd

In [2]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and reload all modules before each cell runs,
# so edits to local packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import matplotlib.pyplot as plt
import numpy as np
import os
import tempfile
from IPython.display import HTML
from base64 import b64encode

from vicas.dataset import ViCaSDataset, ViCaSVideo
from vicas.caption_parsing import parse_caption

## Dataset API

`ViCaSDataset` is a wrapper class to easily iterate over all videos.

**TODO:** Set `annotations_dir` to the directory where the JSON annotations are saved, and `video_frames_dir` to the directory that contains the extracted video frames.

In [None]:
# Locations of the ViCaS annotation JSON files and of the extracted video frames.
annotations_dir = "/data/dataset/ViCaS/annotations/v0.1"
video_frames_dir = "/data/dataset/ViCaS/video_frames/"
split = 'train'

# Index the requested split.
dataset = ViCaSDataset(annotations_dir, split=split, video_frames_dir=video_frames_dir)
print(f"Indexed {len(dataset)} videos from the dataset")

# Handles used by the cells below: the video-id -> annotation-JSON-path mapping,
# and whatever `get_split_videos` returns for the 'train' split.
vid_to_json = dataset.video_id_to_json
vid_splits_train = dataset.get_split_videos('train')

Indexed 5131 videos from the dataset


In [None]:
# Pick one video by id and let the dataset parse it.
example_video_id = 0
video = dataset.parse_video(example_video_id)

In [6]:
import json

# Read the raw annotation JSON that belongs to the example video.
json_file = vid_to_json[example_video_id]
with open(json_file, 'r') as fh:
    content = json.load(fh)

In [None]:
# Parse the raw English caption and show its parsed form.
caption_raw = content['caption_raw_en']
caption_obj = parse_caption(caption_raw)
print(caption_obj.parsed)

# Pair the first id of every parsed caption object with its phrase.
id_phrase_pairs = [
    (mention.ids[0], mention.phrase)
    for mention in caption_obj.objects
]
print(id_phrase_pairs)
# Example output:
# [(1, 'A tattooed man in shorts'), (2, 'rubber band'), (2, 'rubber band')]


A tattooed man in shorts is exercising with a rubber band in the bedroom when one end of the rubber band suddenly loosens and hits his crotch. He covers his crotch and walks to the right in pain.
[(1, 'A tattooed man in shorts'), (2, 'rubber band'), (2, 'rubber band')]


In [8]:
# Gather the track ids of every object referral, then print the distinct ids.
instance_ids = [
    track_id
    for referral in content['object_referrals']
    for track_id in referral['track_ids']
]
print(set(instance_ids))

{1, 2}


In [None]:
import random
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

def overlay_mask_on_image(image, mask, mask_opacity=0.6, mask_color=(0, 255, 0), border_thickness=0):
    """Blend a solid-colour mask over an image and optionally outline the mask.

    Args:
        image: Image array that is combined with a 3-channel copy of the mask.
        mask: 2-D mask, or 3-D mask whose last axis has length 1. Entries
            greater than 0 are treated as foreground.
        mask_opacity: Blend weight of the colour-painted image; the input image
            is weighted by ``1 - mask_opacity``.
        mask_color: Colour painted on foreground pixels (the default is a 3-tuple).
        border_thickness: Side length of the square dilation kernel used to draw
            an outline. 0 disables the outline; a positive value must be odd.

    Returns:
        A uint8 array holding the blended (and optionally outlined) image.
    """
    if mask.ndim == 3:
        assert mask.shape[2] == 1
        flat_mask = mask.squeeze(axis=2)
    else:
        flat_mask = mask

    # Paint the foreground with the mask colour, then alpha-blend with the input.
    mask_3ch = np.stack([flat_mask] * 3, axis=2)
    painted = np.where(mask_3ch > 0, mask_color, image)
    blended = ((mask_opacity * painted) + ((1. - mask_opacity) * image)).astype(np.uint8)

    if not border_thickness > 0:
        return blended

    mask_u8 = flat_mask.astype(np.uint8)
    assert border_thickness % 2 == 1  # odd number
    kernel = np.ones((border_thickness, border_thickness), np.uint8)

    # The outline is the ring that dilation adds around the mask.
    ring = cv2.dilate(mask_u8, kernel, iterations=1) - mask_u8
    ring_3ch = np.stack([ring] * 3, axis=2)

    # Full-size colour plane so the outline pixels take the un-blended mask colour.
    color_plane = np.tile(
        np.array(mask_color, np.uint8),
        (image.shape[0], image.shape[1], 1),
    )
    return np.where(ring_3ch > 0, color_plane, blended)

video_frames_dir = "/data/dataset/ViCaS/video_frames"
num_bins = 4  # frames shown per expression: one random frame from each temporal bin

# For every language-guided expression of the example video, show a few frames
# with the first mask of each frame overlaid in red.
for prompt, masks, track_ids, filenames in video.parse_lgvis(include_auto_annotated_masks=False, return_viz=False):
    num_frames = len(filenames)

    print(prompt)
    print(track_ids)

    if num_frames == 0:
        print("⚠️ Warning: No frames available for this prompt")
        continue

    # Never ask for more bins than there are frames; otherwise bins would be empty.
    num_shown = min(num_bins, num_frames)
    bin_size = num_frames // num_shown
    selected_indices = []
    for i in range(num_shown):
        start_idx = i * bin_size
        # The last bin absorbs the remainder frames.
        end_idx = (i + 1) * bin_size if i != num_shown - 1 else num_frames
        # randrange samples from [start_idx, end_idx). The previous
        # randint(start_idx + 1, end_idx - 1) raised ValueError for bins holding
        # a single frame and could never pick the first frame of a bin.
        selected_indices.append(random.randrange(start_idx, end_idx))

    # squeeze=False keeps `axes` indexable even when only one frame is shown.
    fig, axes = plt.subplots(1, num_shown, figsize=(20, 8), squeeze=False)
    axes = axes[0]
    for ax in axes:
        # Hide ticks up front so panels skipped below do not render empty axes.
        ax.axis("off")

    for i, idx in enumerate(selected_indices):
        frame_path = os.path.join(video_frames_dir, filenames[idx])
        image = cv2.imread(frame_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread returns None instead of raising when the file is missing
            # or unreadable.
            print(f"⚠️ Warning: Could not read frame {frame_path}")
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_list = masks[idx]
        if isinstance(mask_list, list):
            if len(mask_list) > 0:
                # Only the first mask of the frame is visualised.
                mask = mask_list[0]
            else:
                print(f"⚠️ Warning: No masks available for frame {idx}")
                continue
        else:
            mask = mask_list

        # Drop a leading singleton axis so the mask is 2-D for the overlay helper.
        if mask.ndim == 3 and mask.shape[0] == 1:
            mask = mask.squeeze(0)

        overlayed_image = overlay_mask_on_image(image, mask, mask_opacity=0.6, mask_color=(255, 0, 0))

        axes[i].imshow(overlayed_image)
        axes[i].set_title(f"Frame {idx}")

    plt.tight_layout()
    plt.show()

In [None]:
# Paste keys here only for local experiments and never commit real keys.
# Leave a value empty to keep whatever is already exported in the environment.
_api_keys = {
    'OPENAI_API_KEY': "",
    'GEMINI_API_KEY': "",
}
for _key_name, _key_value in _api_keys.items():
    if _key_value:
        os.environ[_key_name] = _key_value
    else:
        # An empty placeholder must not wipe a key that is already configured;
        # the variable is still created (as "") when it was missing.
        os.environ.setdefault(_key_name, "")

### ViCaS benchmark

In [None]:
import os
import json
import openai
import math
import google.generativeai as genai
from collections import defaultdict
from vicas.caption_parsing import parse_caption

def refine_action_based_caption(original_caption, obj_description, frame_name):
    """Ask GPT-4o to rewrite a caption so it states the object's main visible action.

    Args:
        original_caption: Caption text that is inserted into the prompt.
        obj_description: Phrase describing the object the rewrite should focus on.
        frame_name: Not used by this function; kept so existing callers keep working.

    Returns:
        The stripped model reply. ``original_caption`` is returned unchanged when
        the API call fails or the reply is empty.
    """

    prompt = f"""
        You are an AI assistant that specializes in generating **concise**, action-centric, human-centered video descriptions.
        Your task is to rewrite the given caption to focus on the **main visible action** of the object in the scene.

        **Context:**
        - The main object is: {obj_description}

        **Instructions:**
        - Focus on **only one** clear and **visually observable** action the object is performing.
        - If multiple actions are described (e.g., "turns and kicks"), select the most salient or defining action, and describe it **precisely**.
        - Replace vague, mental, or goal-oriented verbs (e.g., "try," "avoid," "prepare") with concrete, visible actions.
        - Eliminate redundant elements such as consequences, causal structures ("causing", "when"), or position-only descriptions.
        - Avoid brackets (`[]`), slashes (`/`), or placeholders (`<maskX>`).
        - Keep the tone objective and the sentence short.

        **Original Caption:** {original_caption}

        **Now, rewrite the caption to clearly describe the object’s main visible action.**
        """

    messages = [{"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}]

    try:
        if hasattr(openai, "OpenAI"):
            # openai>=1.0: `openai.ChatCompletion` was removed; use the client API.
            response = openai.OpenAI().chat.completions.create(
                model="gpt-4o",
                messages=messages,
            )
            content = response.choices[0].message.content
        else:
            # Legacy openai<1.0 module-level API.
            response = openai.ChatCompletion.create(
                model="gpt-4o",
                messages=messages,
            )
            content = response["choices"][0]["message"]["content"]

        caption = content.strip()
        # An empty reply carries no information, so keep the original caption.
        return caption or original_caption

    except Exception as e:
        # Best effort: any failure (network, auth, malformed reply) falls back to
        # the unrefined caption so a long batch run is not aborted.
        print(f"Error in GPT-4o API call: {e}")
        return original_caption

import google.generativeai as genai
import os

# Hand the Gemini key from the environment to the SDK (None when the variable is unset).
genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))

def is_movable_object(obj_description):
    """Decide whether an object description refers to something that moves by itself.

    Cheap local rules are tried first, in this order:

    1. A description containing ``'s`` (e.g. "man's shoulder") is static.
    2. A description containing a human keyword as a whole word is movable.
    3. A description that exactly equals a known static object is static.

    Anything else is sent to Gemini, whose reply must contain "MOVABLE" for the
    object to count as movable.

    Human keywords are matched on whole words. Plain substring matching treated
    words such as "mango" or "command" as humans because they contain "man".
    Compound words such as "policeman" are therefore left to the Gemini call.

    Args:
        obj_description: Free-text phrase describing the object.

    Returns:
        True when the object is classified as movable, otherwise False. API
        errors and empty replies are reported via ``print`` and yield False.
    """
    import re  # local import so this definition does not depend on another cell

    static_objects = ["rubber band", "table", "chair", "cup", "bottle", "lamp", "ball", "tree", "car", "building", "book", "box", "door"]

    # Singular and plural forms, compared against whole words only.
    human_keywords = {
        "man", "men", "woman", "women", "person", "persons", "people",
        "child", "children", "boy", "boys", "girl", "girls", "baby", "babies",
    }

    description = obj_description.lower()

    if "'s" in description:
        return False

    if human_keywords & set(re.findall(r"[a-z]+", description)):
        return True

    if description in static_objects:
        return False

    prompt = f"""
    Determine if the object is MOVABLE or STATIC.

    - MOVABLE: Humans, animals, self-moving objects.
    - STATIC: Vehicles, bicycles, objects requiring external force.
    - **If the object description contains a person (e.g., "man in plaid shirts"), it should always be MOVABLE.**
    - **If the object contains "'s" (e.g., "man's shoulder", "dog's tail"), it should always be STATIC.**
    - If the object is **clothing (pants, shirt, shoes, socks, etc.), classify it as STATIC.**
    - Answer only "MOVABLE" or "STATIC".

    **Object to classify:** {obj_description}
    """

    try:
        model = genai.GenerativeModel("gemini-2.0-flash-lite")
        response = model.generate_content(prompt)

        if response and response.text:
            classification = response.text.strip().upper()
            if "MOVABLE" in classification:
                print(f"✅ {obj_description} is MOVABLE!")
                return True
            else:
                print(f"❌ {obj_description} is STATIC.")
                return False
        else:
            print(f"Error: Could not classify {obj_description}, assuming STATIC.")
            return False

    except Exception as e:
        # Best effort: a failed API call is treated as STATIC instead of aborting.
        print(f"Gemini API 호출 오류: {e}")
        return False

def get_object_description_from_caption(id_phrase_pairs, obj_id):
    """Return the phrase paired with ``obj_id``, or "Unknown Object" if there is none.

    When several pairs carry the same id, the first one in ``id_phrase_pairs`` wins.
    """
    matching_phrases = (
        phrase
        for pair_id, phrase in id_phrase_pairs
        if pair_id == obj_id
    )
    return next(matching_phrases, "Unknown Object")

annotations_dir = "/data/dataset/ViCaS/annotations/v0.1"
video_frames_dir = "/data/dataset/ViCaS/video_frames/"
split = 'train'

dataset = ViCaSDataset(
    annotations_dir,
    split=split,
    video_frames_dir=video_frames_dir
)

# The videos are processed in `num_batches` roughly equal slices; pick the slice
# to run with `current_batch` (1-based).
total_videos = len(dataset.video_id_to_json)
num_batches = 10
sample_size = math.ceil(total_videos / num_batches)

current_batch = 5

start_idx = (current_batch - 1) * sample_size
# `end_idx` is an exclusive slice bound. It used to be computed as an inclusive
# index (start + size - 1) and then passed to a slice, which silently dropped
# the last video of every batch.
end_idx = min(start_idx + sample_size, total_videos)

video_ids = list(dataset.video_id_to_json.keys())[start_idx:end_idx]

# The message reports the inclusive index range, hence `end_idx - 1`.
print(f"✅ Processing batch {current_batch}/{num_batches}: {start_idx} to {end_idx - 1} (Total {total_videos})")
print(f"✅ Selected {len(video_ids)} video IDs")

# Accumulators filled by the generation loop: nested per-video/per-frame output
# and a flat list with one record per processed frame.
output_json = defaultdict(lambda: defaultdict(dict))
samples = []

# Build action-focused captions for every movable object of every selected video.
for video_idx, video_id in enumerate(video_ids, start=1):
    json_file = dataset.video_id_to_json[video_id]

    with open(json_file, 'r') as fh:
        content = json.load(fh)

    caption_raw = content['caption_raw_en']
    caption_obj = parse_caption(caption_raw)
    print(f"\nProcessing Video {video_idx}/{len(video_ids)} ({(video_idx/len(video_ids))*100:.1f}%) - ID: {video_id}")
    print(f"Caption: {caption_raw}")

    # Only the first id of each parsed caption object is paired with its phrase.
    id_phrase_pairs = [(obj.ids[0], obj.phrase) for obj in caption_obj.objects]
    print(f"Extracted Objects: {id_phrase_pairs}")

    video = dataset.parse_video(video_id)

    # Both LLM helpers depend only on the video caption and the object phrase,
    # not on the frame. Remember their answers per video so each object costs
    # one classification and one rewrite, instead of one (or two) per frame.
    movable_by_description = {}
    caption_by_description = {}

    # NOTE: `frame_idx` counts the items yielded by `parse_lgvis`, each of which
    # carries its own list of frame file names; it is not a frame number.
    for frame_idx, frame_info in enumerate(video.parse_lgvis(include_auto_annotated_masks=False, return_viz=False), start=1):
        prompt, masks, track_ids, filenames = frame_info

        print(f"Frame {frame_idx} in Video {video_id}, Track IDs: {track_ids}")

        if not track_ids:
            print(f"Warning: No valid track IDs found for Frame {frame_idx}")
            continue

        if len(filenames) == 0:
            # Nothing to attach captions to, so do not spend API calls.
            continue

        # Passed through to the caption helper, which takes a frame name argument.
        first_frame_name = os.path.basename(filenames[0])

        # str(track id) -> refined caption, shared by all frames of this item.
        instance_dict = {}
        for obj_id in sorted(set(track_ids)):
            obj_description = get_object_description_from_caption(id_phrase_pairs, obj_id)
            print(f"Object: {obj_description} (ID: {obj_id})")

            if obj_description == "Unknown Object":
                continue

            if obj_description not in movable_by_description:
                movable_by_description[obj_description] = is_movable_object(obj_description)
            if not movable_by_description[obj_description]:
                # Skip objects classified as STATIC.
                continue

            if obj_description not in caption_by_description:
                caption_by_description[obj_description] = refine_action_based_caption(
                    caption_raw, obj_description, first_frame_name
                )
            refined_caption = caption_by_description[obj_description]

            print(f"✅ ({video_idx}/{len(video_ids)}) Frame {frame_idx}: Generated Caption: {refined_caption}")

            instance_dict[str(obj_id)] = refined_caption

        if not instance_dict:
            print(f"Warning: No valid moving instances found in Frame {frame_idx}")
            continue

        # setdefault stores a plain dict for the video, as before.
        video_frames = output_json.setdefault(video_id, {})

        for filename in filenames:
            frame_name = os.path.basename(filename)
            # Frame files are named "<number>.<ext>"; the number is the frame id.
            frame_id = int(frame_name.split('.')[0])

            frame_entry = video_frames.setdefault(frame_id, {
                "file_name": frame_name,
                "instances": [],
                "sentences": {}
            })

            # A frame can be reached through several items; list each instance once.
            for instance_key in instance_dict:
                if instance_key not in frame_entry["instances"]:
                    frame_entry["instances"].append(instance_key)
            frame_entry["sentences"].update(instance_dict)

            samples.append({
                "Video ID": video_id,
                "Frame ID": frame_id,
                "Instances": list(instance_dict.keys()),
                # Copy so the records do not all alias one shared dict.
                "Action-Focused Captions": dict(instance_dict)
            })

            print(f"Progress: {len(samples)} samples processed.")

# Derive the file name from the batch being processed so that switching
# `current_batch` cannot overwrite the output of another batch.
output_json_path = f"vicas_action_captions_{current_batch}.json"
with open(output_json_path, "w", encoding="utf-8") as f:
    # ensure_ascii=False keeps non-ASCII characters readable in the file.
    json.dump(output_json, f, ensure_ascii=False, indent=4)

print(f"✅ JSON saved: {output_json_path}")

### ViCaS: per-part JSON generation + visualization code

In [None]:
# import json
# import os
# import textwrap
# import numpy as np
# from PIL import Image, ImageDraw, ImageFont
# from IPython.display import display
# from vicas.dataset import ViCaSDataset

# def overlay_mask_with_labels(image, masks, instance_ids, frame_data, video_id, frame_id, opacity=0.5, frame_index=None, total_frames=None):
#     image_np = np.array(image).copy()
#     overlay = image_np.copy()
#     font = ImageFont.load_default()

#     for idx, (mask, inst_id) in enumerate(zip(masks, instance_ids), 1):
#         if mask.ndim == 3:
#             mask = mask.squeeze()
#         color = [np.random.randint(100, 255) for _ in range(3)]
#         for c in range(3):
#             overlay[:, :, c] = np.where(mask == 1, color[c], overlay[:, :, c])

#     overlay_pil = Image.fromarray(overlay)

#     overlay_np = np.array(overlay_pil)
#     blended = (opacity * overlay_np + (1 - opacity) * image_np).astype(np.uint8)
#     overlay_pil = Image.fromarray(blended)
#     draw = ImageDraw.Draw(overlay_pil)

#     for idx, (mask, inst_id) in enumerate(zip(masks, instance_ids), 1):
#         if mask.ndim == 3:
#             mask = mask.squeeze()
#         ys, xs = np.where(mask == 1)
#         if len(xs) > 0 and len(ys) > 0:
#             cx, cy = int(np.mean(xs)), int(np.mean(ys))
#             draw.text((cx, cy), f"Instance {inst_id}", fill=(255, 0, 0), font=font)

#     width, height = overlay_pil.size
#     if frame_index is not None and total_frames is not None:
#         header = f"Video ID: {video_id} | Frame: {frame_id} ({frame_index + 1} out of {total_frames})"
#     else:
#         header = f"Video ID: {video_id} | Frame: {frame_id}"
#     header_lines = [header]

#     for inst_id in frame_data["instances"]:
#         sentence = frame_data["sentences"].get(inst_id, "")
#         wrapped = textwrap.wrap(f"Instance {inst_id}: {sentence}", width=80)
#         header_lines.extend(wrapped)

#     line_height = 15
#     padding = 10
#     header_height = padding * 2 + line_height * len(header_lines)

#     final_image = Image.new("RGB", (width, height + header_height), color=(255, 255, 255))
#     final_image.paste(overlay_pil, (0, header_height))

#     draw = ImageDraw.Draw(final_image)
#     y_offset = padding
#     for line in header_lines:
#         draw.text((padding, y_offset), line, fill=(0, 0, 0), font=font)
#         y_offset += line_height

#     return final_image

# annotations_dir = "/data/dataset/ViCaS/annotations/v0.1"
# video_frames_dir = "/data/dataset/ViCaS/video_frames"
# split = 'train'
# json_path = "vicas_all.json"
# output_dir = "./vicas_output"
# os.makedirs(output_dir, exist_ok=True)

# dataset = ViCaSDataset(annotations_dir, split=split, video_frames_dir=video_frames_dir)

# with open(json_path, "r", encoding="utf-8") as f:
#     full_data = json.load(f)

# video_ids = list(full_data.keys())
# total_videos = len(video_ids)
# samples_per_part = [600, 1100, 200, 300, 195]
# assert sum(samples_per_part) <= total_videos, "지정한 샘플 수가 전체보다 많습니다."

# selected_part = int(input("파트 번호를 입력하세요 (1 ~ 5): "))
# assert 1 <= selected_part <= 5, "올바른 파트 번호(1 ~ 5)를 입력하세요."

# start_idx = sum(samples_per_part[:selected_part - 1])
# end_idx = start_idx + samples_per_part[selected_part - 1]
# selected_video_ids = video_ids[start_idx:end_idx]

# print(f"총 {total_videos}개 중 {len(selected_video_ids)}개 선택됨 (파트 {selected_part})")

# part_data = {vid: full_data[vid] for vid in selected_video_ids}
# output_json_path = f"vicas_part{selected_part}.json"
# with open(output_json_path, "w", encoding="utf-8") as f:
#     json.dump(part_data, f, ensure_ascii=False, indent=4)
# print(f"✅ JSON saved: {output_json_path}")

# for vid_idx, vid in enumerate(selected_video_ids, 1): 
#     video = dataset.parse_video(int(vid))
#     frames = part_data[vid]
#     sorted_frame_ids = sorted([int(fid) for fid in frames.keys()])

#     for frame_index, frame_id in enumerate(sorted_frame_ids):
#         frame_data = frames[str(frame_id)]
#         file_name = frame_data["file_name"]
#         image_path = os.path.join(video_frames_dir, str(vid).zfill(6), file_name)

#         if not os.path.exists(image_path):
#             print(f"⚠️ 이미지 없음: {image_path}")
#             continue

#         image = Image.open(image_path).convert("RGB")
#         instances = frame_data["instances"]
#         instance_ids = [int(i) for i in instances]

#         matched_masks = []
#         matched_ids = []
#         found = False

#         for prompt, masks, track_ids, filenames in video.parse_lgvis(include_auto_annotated_masks=False, return_viz=False):
#             for idx, fname in enumerate(filenames):
#                 if os.path.basename(fname) == file_name:
#                     for inst_id in instance_ids:
#                         if inst_id in track_ids:
#                             inst_index = track_ids.index(inst_id)
#                             matched_masks.append(masks[idx][inst_index])
#                             matched_ids.append(inst_id)

#                     if matched_masks:
#                         vis_img = overlay_mask_with_labels(
#                             image, matched_masks, matched_ids,
#                             frame_data, vid, frame_id,
#                             frame_index=frame_index, total_frames=len(sorted_frame_ids)
#                         ,opacity=0.5
#                         )
#                         display(vis_img)

#                     found = True
#                     break
#             if found:
#                 break

In [None]:
# This code only visualizes the images! Do not run.
# import json
# import os
# import textwrap
# from IPython.display import display
# from PIL import Image, ImageDraw, ImageFont

# json_path = "vicas_all.json"
# image_dir = "/data/dataset/ViCaS/video_frames/"
# output_dir = "./vicas_output"
# os.makedirs(output_dir, exist_ok=True)

# with open(json_path, "r", encoding="utf-8") as f:
#     full_data = json.load(f)

# video_ids = list(full_data.keys())
# total_videos = len(video_ids)

# samples_per_part = [600, 1100, 200, 300, 195]
# assert sum(samples_per_part) <= total_videos, "지정한 샘플 수가 전체보다 많습니다."

# selected_part = int(input("파트 번호를 입력하세요 (1 ~ 5): "))
# assert 1 <= selected_part <= 5, "올바른 파트 번호(1 ~ 5)를 입력하세요."

# start_idx = sum(samples_per_part[:selected_part - 1])
# end_idx = start_idx + samples_per_part[selected_part - 1]
# selected_video_ids = video_ids[start_idx:end_idx]

# print(f"총 {total_videos}개 중 {len(selected_video_ids)}개 선택됨 (파트 {selected_part})")

# part_data = {vid: full_data[vid] for vid in selected_video_ids}
# output_json_path = f"vicas_part{selected_part}.json"
# with open(output_json_path, "w", encoding="utf-8") as f:
#     json.dump(part_data, f, ensure_ascii=False, indent=4)
# print(f"✅ JSON saved: {output_json_path}")

# font = ImageFont.load_default()

# for vid_idx, vid in enumerate(selected_video_ids, 1):
#     frames = part_data[vid]
#     sorted_frame_ids = sorted([int(fid) for fid in frames.keys()])

#     for frame_id in sorted_frame_ids:
#         frame_data = frames[str(frame_id)]
#         file_name = frame_data["file_name"]

#         video_id_str = str(vid).zfill(6)
#         image_path = os.path.join(image_dir, video_id_str, file_name)

#         if not os.path.exists(image_path):
#             print(f"⚠️ 이미지 없음: {image_path}")
#             continue

#         image = Image.open(image_path).convert("RGB")
#         width, height = image.size

#         instances = frame_data["instances"]
#         sentences = frame_data["sentences"]
#         text_lines = [f"Video ID: {video_id_str} | Frame: {frame_id}"]

#         for inst in instances:
#             sentence = sentences[inst]
#             wrapped = textwrap.wrap(f"Instance {inst}: {sentence}", width=80)
#             text_lines.extend(wrapped)

#         line_height = 15
#         padding = 10
#         header_height = padding * 2 + line_height * len(text_lines)

#         new_image = Image.new("RGB", (width, height + header_height), color=(255, 255, 255))
#         new_image.paste(image, (0, header_height))
#         draw = ImageDraw.Draw(new_image)

#         y_offset = padding
#         for line in text_lines:
#             draw.text((padding, y_offset), line, font=font, fill=(0, 0, 0))
#             y_offset += line_height

#         display(new_image)

In [None]:
# samples_per_part = [600, 1100, 200, 300, 195]

# start_end_indices = []
# start_idx = 0

# for count in samples_per_part:
#     end_idx = start_idx + count
#     start_end_indices.append((start_idx, end_idx))
#     start_idx = end_idx

# for i, (start, end) in enumerate(start_end_indices, 1):
#     print(f"Part {i}: start_idx = {start}, end_idx = {end}, count = {end - start}")