In [None]:
import os
import numpy as np
import h5py
from PIL import Image
import matplotlib.pyplot as plt
import torch
import open_clip

### 0. Load Dataset

In [None]:
"""
    Optional: Download sample dataset from GDrive (7.07GB)
"""
%pip install --upgrade gdown

import os
import zipfile
import gdown

file_id = '1dmOHCXq7CvSoY1mEq0ISvKxJA_kmQQkG'
download_url = f'https://drive.google.com/uc?id={file_id}'

output_path = 'dataset/file.zip'
gdown.download(download_url, output_path, quiet=False)

with zipfile.ZipFile(output_path, 'r') as zip_ref:
    zip_ref.extractall('dataset')
os.remove(output_path)

In [None]:
# All HDF5 demonstrations in `dataset/`, in deterministic (sorted) order.
# `endswith` avoids picking up e.g. "demo.h5.bak" or "x.h5.tmp".
file_names = sorted(os.path.join("dataset", f) for f in os.listdir("dataset") if f.endswith(".h5"))
print("num file:", len(file_names))

### 1. Download CLIP Models

In [None]:
# Configs
# Use the first GPU when available; fall back to CPU so the notebook still runs (slowly) without one.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 128  # crops per CLIP forward pass (each step contributes a left and a right crop)

In [None]:
# Download the CLIP image encoder (ViT-B-32, pretrained tag "laion2b_s34b_b79k"),
# move it to `device` and switch it to inference mode.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained="laion2b_s34b_b79k",
)
model = model.to(device).eval()

### 2. Hyperparameters

In [None]:
# Hyperparameters (chosen for demonstrations recorded at 10 Hz)
window_size = 20  # half-width, in steps, of the sliding median window applied to the gaze track
transition_thresh = 50  # gaze change score (pixel distance) a transition step must exceed
transition_thresh2 = 0.04  # image-feature change score a transition step must also exceed
min_seg_size = round(1.5 * window_size)  # shortest segment (in steps) kept after thinning
crop_size_original = 256  # gaze-crop side in px, defined w.r.t. a 1280x720 image

### 3. Measure Change Scores

In [None]:
def extract_gaze_info(
    demo_file: str,
    window_size: int,
    crop_size_original: int,
    batch_size: int,
    device: torch.device,
    model,
    preprocess,
) -> tuple[np.ndarray, np.ndarray]:
    """Smooth the gaze track of one demonstration and embed gaze-centred crops with CLIP.

    Args:
        demo_file: Path to an HDF5 demonstration with datasets ``left_img`` and
            ``right_img`` of shape (N, H, W, C), stored BGR, and ``gaze``, which is
            reshaped to (N, 2, 2) = (left/right camera, (x, y)) pixel coordinates.
        window_size: Half-width, in steps, of the sliding median window applied to
            the gaze track.
        crop_size_original: Crop side length in pixels for a 1280-px-wide image;
            rescaled to the stored image width.
        batch_size: Number of crops per encoder forward pass. Every step yields a
            left and a right crop, so ``max(1, batch_size // 2)`` steps are encoded
            per batch.
        device: Device the model lives on. Mixed precision is used only on CUDA.
        model: open_clip model exposing ``encode_image(images, normalize=True)``.
        preprocess: Transform mapping a PIL image to an image tensor.

    Returns:
        gazes_median: (N, 2, 2) int64 median-filtered gaze positions.
        image_features: (N, 2, feature_dim) float32 unit-norm image features of the
            left/right gaze crops.

    Raises:
        ValueError: If the demonstration contains no frames.
    """
    with h5py.File(demo_file, "r") as demo:
        eps_steps, H, W, C = demo["left_img"].shape
        if eps_steps == 0:
            raise ValueError(f"{demo_file} contains no frames")
        crop_size = int(crop_size_original * W / 1280)
        half = crop_size // 2

        gazes = np.round(demo["gaze"]).astype(np.int64).reshape(-1, 2, 2)  # (N, 2, 2)
        # Clip to [0, W] x [0, H]; the padding below keeps every crop in bounds.
        gazes = np.clip(gazes, [0, 0], [W, H])

        # Sliding median over steps [i - window_size, i + window_size] suppresses gaze jitter.
        gazes_median = np.array(
            [np.median(gazes[max(0, i - window_size) : i + window_size + 1], axis=0) for i in range(len(gazes))]
        ).astype(np.int64)  # (N, 2, 2)

        # Zero-padded canvas, reused across steps (only its centre is overwritten), so that a
        # crop starting at padded coordinate (x, y) is centred on image pixel (x, y).
        pad_img = np.zeros((2, H + crop_size, W + crop_size, C), dtype=np.uint8)
        imgs = np.empty((eps_steps, 2, crop_size, crop_size, C), dtype=np.uint8)  # (N, 2, crop_size, crop_size, C)
        for step in range(eps_steps):
            img = np.stack([demo["left_img"][step], demo["right_img"][step]])  # (2, H, W, C)
            pad_img[:, half : H + half, half : W + half] = img[..., [2, 1, 0]]  # BGR2RGB
            for lr, (x, y) in enumerate(gazes_median[step]):
                imgs[step, lr] = pad_img[lr, y : y + crop_size, x : x + crop_size]

    # Step through the episode with a fixed stride: the previous `i * batch_size // 2` bounds
    # produced empty trailing batches for odd batch sizes, and batch_size < 2 divided by zero.
    steps_per_batch = max(1, batch_size // 2)
    image_features = []
    for start in range(0, eps_steps, steps_per_batch):
        crops = imgs[start : start + steps_per_batch].reshape(-1, *imgs.shape[2:])  # (B * 2, crop_size, crop_size, C)
        # Presumably (B * 2, C, H', W') for a torchvision-style preprocess.
        batch = torch.stack([preprocess(Image.fromarray(crop)) for crop in crops]).to(device)

        with torch.no_grad(), torch.autocast(device_type="cuda", enabled=device.type == "cuda"):
            feats = model.encode_image(batch, normalize=True)  # (B * 2, feature_dim)
        # Cast to float32 so downstream NumPy dot products are not done in half precision.
        image_features.append(feats.float().cpu().numpy().reshape(-1, 2, feats.shape[1]))  # (B, 2, feature_dim)
    image_features = np.concatenate(image_features)  # (N, 2, feature_dim)

    return gazes_median, image_features

In [None]:
def change_score(gazes_median: np.ndarray, image_features: np.ndarray) -> tuple[list, list]:
    """Score how much the gaze and the gazed-at content change between consecutive steps.

    Args:
        gazes_median: (N, 2, 2) smoothed gaze positions (left/right camera, (x, y)).
        image_features: (N, 2, feature_dim) per-step crop features (expected unit-norm).

    Returns:
        ``(gaze_scores, feature_scores)``: two lists whose entry ``t`` (t >= 1) compares
        step ``t - 1`` with step ``t``; entry 0 is 0 (for N == 0 both are just ``[0]``).

        - gaze score: Euclidean norm of the change of all four gaze coordinates.
        - feature score: ``log 2 - log(cos + 1 + 1e-6)`` averaged over the two cameras,
          where ``cos`` is the dot product of the two steps' features. It is ~0 for
          identical features and grows as the crops become dissimilar (it is a
          dissimilarity, not a similarity).
    """
    gaze_scores = [0]
    feature_scores = [0]
    for t in range(1, len(gazes_median)):
        # Frobenius norm of the (2, 2) difference == L2 norm of the flattened 4-vector.
        gaze_scores.append(np.linalg.norm(gazes_median[t] - gazes_median[t - 1]))

        prev_feat, curr_feat = image_features[t - 1], image_features[t]  # (2, feature_dim) each
        cos = (prev_feat @ curr_feat.T).diagonal()  # (2,) per-camera dot products
        dissimilarity = np.log(2) - np.log(cos + 1 + 1e-6)  # (2,)
        feature_scores.append(dissimilarity.mean(0))

    return gaze_scores, feature_scores


In [None]:
# Compute per-step change scores (gaze shift + image-feature change) for every demonstration.
data_score_gazes, data_score_features, data_eps_steps = [], [], []
for eps_idx, file_name in enumerate(file_names):
    # Median-smoothed gaze track and features of the gaze crops.
    gazes_median, image_features = extract_gaze_info(
        file_name,
        window_size,
        crop_size_original,
        batch_size,
        device,
        model,
        preprocess,
    )
    print(f"Demo info: {eps_idx} [file_name={file_name}, eps_step={len(gazes_median)}]")

    score_gazes, score_features = change_score(gazes_median, image_features)
    data_score_gazes.append(score_gazes)
    data_score_features.append(score_features)
    data_eps_steps.append(len(gazes_median))

# Episode start offsets on a concatenated step axis, followed by the total: (N_eps + 1,)
data_eps_steps = np.cumsum([0] + data_eps_steps)
print(data_eps_steps)

### 4. Detect Gaze Transitions by Change Scores

In [None]:
def detect_transition(
    score_gazes: list,
    score_features: list,
    thresh: float,
    thresh2: float,
    min_seg_size: int,
) -> list:
    """Find gaze-transition steps from the change scores, then thin out short segments.

    A step ``t`` (excluding the final step) is a transition candidate when
    ``score_gazes[t] > thresh`` AND ``score_features[t] > thresh2``. The boundaries are
    ``[0, *candidates, len(score_gazes)]``.

    While some segment is shorter than ``min_seg_size`` and at least four boundaries remain,
    the shortest segment (the first one on ties) is merged into a neighbour by dropping one
    of its boundaries:
    - first segment: drop its right boundary;
    - last segment: drop its left boundary;
    - otherwise: drop whichever of its two boundaries has the lower weighted score
      ``thresh * score_gaze + thresh2 * score_feature`` (the left one on ties).

    Returns:
        Non-decreasing list of boundary steps, starting at 0 and ending at ``len(score_gazes)``.
    """
    n_steps = len(score_gazes)

    # Boundary list: 0, every step whose two scores both exceed their thresholds, then n_steps.
    boundaries = [0]
    for t in range(n_steps - 1):
        s_gaze, s_feat = score_gazes[t], score_features[t]
        if s_gaze > thresh and s_feat > thresh2:
            boundaries.append(t)
    boundaries.append(n_steps)

    # Weighted score used to decide which boundary of an interior short segment to drop.
    weighted = thresh * np.array(score_gazes) + thresh2 * np.array(score_features)

    while True:
        sizes = np.diff(boundaries)
        # Stop once every segment is long enough, or when only one interior boundary is left
        # (NOTE(review): that last interior boundary is kept even if it leaves a short segment).
        if np.all(sizes >= min_seg_size) or len(boundaries) < 4:
            break
        assert len(sizes) > 1

        shortest = np.argmin(sizes)  # index of the first shortest segment
        if shortest == 0:
            drop = 1  # first segment: remove its right boundary
        elif shortest == len(sizes) - 1:
            drop = shortest  # last segment: remove its left boundary
        elif weighted[boundaries[shortest + 1]] < weighted[boundaries[shortest]]:
            drop = shortest + 1  # right boundary is the weaker one
        else:
            drop = shortest  # left boundary is weaker (or tied)
        del boundaries[drop]

    return boundaries


In [None]:
def visualize_detection(
    file_name: str,
    score_gazes: list,
    score_features: list,
    thresh: float,
    thresh2: float,
    seg_steps: list,
    crop_size_original: int,
):
    """Plot an episode's change scores and show left-camera gaze crops around each transition.

    Args:
        file_name: HDF5 demonstration (``left_img``/``right_img`` stored BGR, plus ``gaze``).
        score_gazes: Per-step gaze change scores.
        score_features: Per-step feature change scores. They are rescaled by
            ``thresh / thresh2`` so that both thresholds coincide with the drawn line.
        thresh: Gaze-score threshold (drawn as a horizontal line).
        thresh2: Feature-score threshold.
        seg_steps: Segment boundaries ``[0, *transitions, eps_steps]``; interior ones are
            drawn as red vertical lines.
        crop_size_original: Crop side in pixels for a 1280-px-wide image.
    """
    plt.plot(score_gazes, label="s_pos", c="sandybrown", lw=2)
    plt.plot(np.array(score_features) * thresh / thresh2, label="s_feat", c="cornflowerblue", lw=1)
    plt.hlines([thresh], 0, len(score_gazes), colors="black", lw=2)
    plt.vlines(np.array(seg_steps[1:-1]), 0, 30, colors="red", alpha=0.7, lw=3)
    plt.ylim(0, 150.0)
    plt.legend()
    plt.show()

    NN = 9  # frames shown on each side of a transition
    with h5py.File(file_name, "r") as demo:
        eps_steps = len(demo["left_img"])
        for step in seg_steps[1:-1]:
            # Clamp the window to the episode: a negative start would count from the end of the
            # dataset and give an empty/wrong selection for transitions earlier than NN.
            lo, hi = max(0, step - NN), min(eps_steps, step + NN + 1)
            imgs = np.stack([demo["left_img"][lo:hi], demo["right_img"][lo:hi]], axis=1)  # (N, 2, H, W, C)
            imgs = np.ascontiguousarray(imgs[..., [2, 1, 0]])  # BGR2RGB
            gazes = np.round(demo["gaze"][lo:hi]).astype(np.int64).reshape(-1, 2, 2)  # (N, 2, 2)

            N, _, H, W, C = imgs.shape
            crop_size = int(crop_size_original * W / 1280)
            half = crop_size // 2

            # Zero padding keeps crops in bounds: the crop starting at padded (x, y) is centred
            # on image pixel (x, y).
            pad_imgs = np.zeros((N, 2, H + crop_size, W + crop_size, C), dtype=np.uint8)
            pad_imgs[:, :, half : H + half, half : W + half] = imgs

            gazes = np.clip(gazes, [0, 0], [W, H])

            gaze_imgs = np.stack(
                [
                    np.stack([pad_imgs[n, lr, y : y + crop_size, x : x + crop_size] for lr, (x, y) in enumerate(gazes[n])])
                    for n in range(N)
                ]
            )  # (N, 2, crop_size, crop_size, C)

            print(f"step: {step}, scores: {thresh * np.array(score_gazes[step]) + thresh2 * np.array(score_features[step])} ({score_gazes[step]} + {score_features[step]})")
            display(Image.fromarray(np.concatenate(gaze_imgs[:, 0], axis=1)).resize((64 * N, 64)))

In [None]:
# Detect gaze transitions for every episode that has change scores, and visualize each result.
data_seg_steps = []
for eps_idx, file_name in enumerate(file_names):
    if eps_idx == len(data_score_gazes):
        break  # no scores were computed beyond this episode

    print("===================================================================")
    print(f"Demo info: {eps_idx} [file_name={file_name}, eps_step={len(data_score_gazes[eps_idx])}]")

    seg_steps = detect_transition(
        score_gazes=data_score_gazes[eps_idx],
        score_features=data_score_features[eps_idx],
        thresh=transition_thresh,
        thresh2=transition_thresh2,
        min_seg_size=min_seg_size,
    )
    data_seg_steps.append(seg_steps)
    print("seg_steps:", seg_steps)

    visualize_detection(
        file_name=file_name,
        score_gazes=data_score_gazes[eps_idx],
        score_features=data_score_features[eps_idx],
        thresh=transition_thresh,
        thresh2=transition_thresh2,
        seg_steps=seg_steps,
        crop_size_original=crop_size_original,
    )

### 5. Mode of Segment Counts

In [None]:
# The most frequent segment count across the segmented episodes is taken as the expected number
# of sub-tasks. Uses `data_seg_steps` directly instead of the loop variable leaked from the
# previous cell (hidden state that breaks when cells are re-run out of order).
if not data_seg_steps:
    raise ValueError("No segmentations available; run the transition-detection cell first.")
num_segs, num_seg_count = np.unique([len(seg_steps) for seg_steps in data_seg_steps], return_counts=True)
seg_num = num_segs[np.argmax(num_seg_count)] - 1  # boundaries include 0 and eps_steps: #segments = #boundaries - 1
seg_nums = np.full(len(file_names), seg_num)  # expected sub-task count for every episode

print(f"Majority of sub-task counts: {seg_num} (num_segs: {num_segs - 1}, count: {num_seg_count})")
print("seg_nums:\n", seg_nums)

### 6. Refine Detection Results

In [None]:
# Refine detections: for every episode whose segment count differs from the majority, re-run
# detection with both thresholds scaled by the same factor (1 -/+ k * delta) until it matches.
failure_count = 0
delta = 0.01  # relative threshold change per iteration
max_iter = 500
for eps_idx, file_name in enumerate(file_names):
    if eps_idx == len(data_score_gazes):
        break

    seg_steps = data_seg_steps[eps_idx].copy()
    target_num = seg_nums[eps_idx]
    init_num = len(seg_steps) - 1
    if init_num == target_num:
        continue

    # Too few segments -> lower the thresholds; too many -> raise them.
    direction = -1 if init_num < target_num else 1
    new_thresh, new_thresh2 = transition_thresh, transition_thresh2
    for k in range(1, max_iter + 1):
        # Scale from the original thresholds (no accumulated float drift).
        scale = 1 + direction * delta * k
        if scale <= 0:
            # Non-positive thresholds accept every step (even zero scores) as a transition,
            # so lowering them further cannot yield a meaningful segmentation.
            break
        new_thresh = transition_thresh * scale
        new_thresh2 = transition_thresh2 * scale

        seg_steps = detect_transition(
            data_score_gazes[eps_idx],
            data_score_features[eps_idx],
            new_thresh,
            new_thresh2,
            min_seg_size,
        )
        if len(seg_steps) - 1 == target_num:
            break

    print("===================================================================")
    print(f"Demo info: {eps_idx} [file_name={file_name}, eps_step={len(data_score_gazes[eps_idx])}]")
    if len(seg_steps) - 1 == target_num:
        print(f"Refinement result: {data_seg_steps[eps_idx]} --> {seg_steps}")
    else:
        print(f"[Failed to refine segmentation: {seg_steps} (init={data_seg_steps[eps_idx]})]")
        failure_count += 1

    data_seg_steps[eps_idx] = seg_steps

    # Plot with the thresholds that actually produced `seg_steps`.
    visualize_detection(
        file_name,
        data_score_gazes[eps_idx],
        data_score_features[eps_idx],
        new_thresh,
        new_thresh2,
        seg_steps,
        crop_size_original,
    )

print("num failure:", failure_count)

### 7. Save Segmentation Results

In [None]:
# Save results: write `change_steps` (the segment boundaries) into every episode whose segmentation
# has the majority number of sub-tasks. Any previously saved `change_steps` is deleted first, so
# invalid or excluded episodes end up without one.
invalid_episodes = []  # pre-fill with episode basenames to exclude them manually
num_saved = 0
for eps_idx, file_name in enumerate(file_names):
    episode = os.path.basename(file_name)  # invalid_episodes consistently holds basenames
    with h5py.File(file_name, "r+") as demo:
        if "change_steps" in demo:
            del demo["change_steps"]

        if episode in invalid_episodes:
            continue  # manually excluded

        # Episodes that were never segmented, or whose count differs from the majority, are invalid.
        if eps_idx >= len(data_seg_steps) or len(data_seg_steps[eps_idx]) - 1 != seg_nums[eps_idx]:
            invalid_episodes.append(episode)
            continue

        print(f"episode {eps_idx}: {data_seg_steps[eps_idx]}")
        demo.create_dataset("change_steps", data=np.array(data_seg_steps[eps_idx]))
        num_saved += 1

print("invalid episodes:", invalid_episodes)
print("successful episodes:", num_saved)

### 8. Check Saved Results

In [None]:
# Visual check of the saved segmentation: for every sub-task show NN evenly spaced frames of the
# left camera, dimmed everywhere except the crop around the gaze point.
NN = 4  # frames shown per sub-task
for eps_idx, file_name in enumerate(file_names):
    with h5py.File(file_name, "r") as demo:
        eps_steps = len(demo["left_img"])

        print("===================================================================")
        print(f"Demo info: {eps_idx} [file_name={file_name}, eps_step={eps_steps}]")

        if "change_steps" not in demo:
            continue
        change_steps = demo["change_steps"][()]

        for subtask_idx, (init_step, change_step) in enumerate(zip(change_steps[:-1], change_steps[1:])):
            print(f"[Sub-task {subtask_idx + 1}]")
            if change_step <= init_step:
                continue  # empty sub-task: nothing to show

            # Evenly spaced frames from the first to the last step of the sub-task. h5py point
            # selection needs strictly increasing indices, hence np.unique for very short sub-tasks.
            frame_ids = np.unique(np.linspace(init_step, change_step - 1, NN).astype(np.int64)).tolist()
            imgs = np.stack([demo["left_img"][frame_ids], demo["right_img"][frame_ids]], axis=1)  # (N, 2, H, W, C)
            imgs = np.ascontiguousarray(imgs[..., [2, 1, 0]])  # BGR2RGB
            gazes = np.round(demo["gaze"][frame_ids]).astype(np.int64).reshape(-1, 2, 2)  # (N, 2, 2)

            N, _, H, W, C = imgs.shape
            half = int(crop_size_original * W / 1280) // 2
            gazes = np.clip(gazes, [0, 0], [W, H])

            # Halve brightness outside the gaze crop (integer halving == floor(v * 0.5) for uint8).
            masked_imgs = imgs // 2
            for n in range(N):
                for lr, (x, y) in enumerate(gazes[n]):
                    # Clamp the start at 0: near the top/left edge a negative start would wrap
                    # around to the far edge and select nothing.
                    rows = slice(max(0, y - half), y + half + 1)
                    cols = slice(max(0, x - half), x + half + 1)
                    masked_imgs[n, lr, rows, cols] = imgs[n, lr, rows, cols]

            display(Image.fromarray(np.concatenate(masked_imgs[:, 0], axis=1)).resize((320 * N, 180)))