In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from skimage.io import imread
from skimage.color import rgb2gray
from skimage.util import img_as_ubyte
from scipy.signal import coherence
import warnings

In [None]:
# Side length of the square blocks the frames are tiled into; passed as
# `block_size` to extract_tc_features and shown in the heatmap title.
BLOCK_SIZE = 32

In [None]:
def compute_tc(a, b, fs=1.0, nperseg=256):
    """Return the mean coherence between two signals.

    The coherence spectrum is estimated with ``scipy.signal.coherence`` and
    averaged over every frequency bin whose value is finite.
    NOTE(review): "tc" presumably stands for temporal coherence -- confirm.

    Parameters
    ----------
    a, b : array_like
        Signals to compare.
    fs : float, optional
        Sampling frequency forwarded to ``coherence``.
    nperseg : int or None, optional
        Requested segment length. It is capped at the signal length so that
        short inputs do not trigger SciPy's "nperseg is greater than input
        length" warning; ``None`` is forwarded unchanged.

    Returns
    -------
    float
        Mean of the finite coherence values, or ``0.0`` when there are none
        (empty input, or every bin being NaN/inf).
    """
    a = np.asarray(a)
    b = np.asarray(b)

    # Nothing to compare: report "no coherence" instead of relying on how the
    # installed SciPy version treats empty arrays.
    if a.size == 0 or b.size == 0:
        return 0.0

    if nperseg is not None:
        nperseg = min(nperseg, a.shape[-1], b.shape[-1])

    with warnings.catch_warnings():
        # Bins with zero power divide 0 by 0; the resulting NaNs are dropped
        # below, so the accompanying RuntimeWarning is just noise.
        warnings.filterwarnings('ignore', category=RuntimeWarning)
        _, Cxy = coherence(a, b, fs=fs, nperseg=nperseg)

    finite = Cxy[np.isfinite(Cxy)]
    return float(finite.mean()) if finite.size else 0.0

def extract_tc_features(ref_frame, tar_frame, block_size=32, fs=1.0, nperseg=256,
                        verbose=False):
    """Compute a block-wise coherence map between two 2-D frames.

    Both frames are tiled into non-overlapping ``block_size`` x ``block_size``
    patches. Each pair of co-located patches is flattened and scored with
    ``compute_tc``. Rows and columns left over when the frame size is not a
    multiple of ``block_size`` are ignored.

    Parameters
    ----------
    ref_frame, tar_frame : numpy.ndarray
        2-D arrays of identical shape.
    block_size : int, optional
        Side length of each square block. ``0`` means "no tiling": the whole
        frames are compared and a single score is returned.
    fs : float, optional
        Sampling frequency forwarded to ``compute_tc``.
    nperseg : int, optional
        Segment length forwarded to ``compute_tc``.
    verbose : bool, optional
        When true, print the score of every block as it is computed.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(h // block_size, w // block_size)`` holding one
        score per block, or shape ``(1, 1)`` when ``block_size`` is 0.

    Raises
    ------
    ValueError
        If the frames are not 2-D, their shapes differ, or ``block_size`` is
        negative.
    """
    h, w = ref_frame.shape
    if tar_frame.shape != ref_frame.shape:
        # Mismatched frames would be sliced into patches of different lengths
        # and compared anyway, so fail loudly here instead.
        raise ValueError(
            f"ref_frame and tar_frame must have the same shape, "
            f"got {ref_frame.shape} and {tar_frame.shape}"
        )
    if block_size < 0:
        raise ValueError(f"block_size must be non-negative, got {block_size}")

    if block_size == 0:
        tc_score = compute_tc(ref_frame.flatten(), tar_frame.flatten(), fs=fs, nperseg=nperseg)
        return np.array([[tc_score]])

    n_rows, n_cols = h // block_size, w // block_size
    tc_map = np.empty((n_rows, n_cols), dtype=float)

    for row in range(n_rows):
        y = row * block_size
        for col in range(n_cols):
            x = col * block_size
            ref_patch = ref_frame[y:y+block_size, x:x+block_size].flatten()
            tar_patch = tar_frame[y:y+block_size, x:x+block_size].flatten()
            tc_score = compute_tc(ref_patch, tar_patch, fs=fs, nperseg=nperseg)

            if verbose:
                print(f"TC at block (y={y}, x={x}) = {tc_score:.4f}")
            tc_map[row, col] = tc_score

    return tc_map

In [None]:
# Load every frame of the sequence as an 8-bit grayscale image.
frame_dir = Path("dataset/DAVIS/curling")
frame_paths = sorted(frame_dir.glob("*.jpg"))

# The comparison below needs a reference and a target frame; fail with a clear
# message instead of an IndexError when the directory is missing or too small.
if len(frame_paths) < 2:
    raise FileNotFoundError(
        f"Need at least two .jpg frames in {frame_dir}, found {len(frame_paths)}"
    )

gray_frames = []
for path in frame_paths:
    image = imread(path)
    # Colour images are converted to grayscale; 2-D images are used as-is.
    gray = rgb2gray(image) if image.ndim == 3 else image
    gray_frames.append(img_as_ubyte(gray))

# Compare the first two frames of the sequence block by block.
ref_frame, tar_frame = gray_frames[0], gray_frames[1]
tc_map = extract_tc_features(ref_frame, tar_frame, block_size=BLOCK_SIZE)

# Heatmap with one cell per block.
fig, ax = plt.subplots(figsize=(6, 5))
heatmap = ax.imshow(tc_map, cmap='hot', interpolation='nearest')
ax.set_title(f"TC Heatmap (block_size={BLOCK_SIZE})")
fig.colorbar(heatmap, ax=ax)
ax.axis("off")
fig.tight_layout()
plt.show()