In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from momaapi import MOMA
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

In [2]:
# Configuration: the two values a reader is most likely to change, named here
# instead of being buried in the constructor calls below.
MOMA_DIR = "/data/dir_moma"  # dataset directory handed to the MOMA API
SBERT_MODEL_NAME = "all-MiniLM-L6-v2"  # model identifier handed to SentenceTransformer

moma = MOMA(MOMA_DIR)
sbert = SentenceTransformer(SBERT_MODEL_NAME)

In [3]:
# Maps each activity id to the sentence embeddings of its sub-activity class
# names, one embedding per sub-activity in the order the API returns them.
vid2seqembs = {}
ids_act = moma.get_ids_act()

activities = tqdm(moma.get_anns_act(ids_act=ids_act), desc="[MOMA] preprocessing")
for act in activities:
    # Class-name sequence of this activity's sub-activities.
    seq = [sact.cname for sact in moma.get_anns_sact(ids_sact=act.ids_sact)]
    with torch.no_grad():
        seq_emb = sbert.encode(seq)
    vid2seqembs[act.id] = seq_emb

[MOMA] preprocessing: 100%|██████████| 1412/1412 [00:56<00:00, 25.19it/s]


In [4]:
def _unit_rows(seq):
    """Return ``seq`` as an array whose rows are scaled to unit L2 norm.

    A non-empty 1-D input is treated as a one-step sequence. Rows whose norm is
    zero are left as all-zero rows instead of being divided by zero, so they
    yield a cosine similarity of 0 rather than NaN.
    """
    arr = np.asarray(seq)
    if arr.ndim == 1 and arr.size > 0:
        arr = arr[np.newaxis, :]
    norms = np.linalg.norm(arr, axis=-1, keepdims=True)
    return arr / np.where(norms > 0, norms, 1.0)


def compute_dtw_score(x, y, eps, w):
    """Align two embedding sequences with a similarity-maximising DTW.

    Rows of ``x`` and ``y`` are L2-normalised and compared by dot product
    (cosine similarity). A dynamic program accumulates the best total
    similarity over monotone alignments that use diagonal, vertical and
    horizontal steps, and the optimal path is then recovered by backtracking.

    Parameters
    ----------
    x, y : array-like
        Sequences of embedding vectors, one vector per row, with the same
        embedding dimension.
    eps : unused
        Accepted for interface compatibility; the implementation ignores it.
    w : unused
        Accepted for interface compatibility; no warping-window constraint is
        applied, so the alignment is unconstrained.

    Returns
    -------
    R : numpy.ndarray
        Accumulated-similarity table of shape ``(len(x) + 1, len(y) + 1)``.
        Row 0 and column 0 are ``-inf`` padding, except ``R[0, 0] == 0``.
    score : float
        ``R[-1, -1]`` divided by the number of cells on the optimal path, or
        ``0.0`` when either sequence is empty.
    path : list of (int, int)
        Cells of the optimal path as 1-based ``(i, j)`` indices into ``R``,
        ordered from the end ``(len(x), len(y))`` back to the start.

    Raises
    ------
    ValueError
        If backtracking cannot identify a predecessor cell, which happens when
        the table contains NaN (for example from NaN inputs).
    """
    nx = _unit_rows(x)
    ny = _unit_rows(y)
    m, n = nx.shape[0], ny.shape[0]

    # Every interior cell is overwritten below; -inf on the border forbids
    # alignments that start anywhere other than (1, 1).
    R = np.full((m + 1, n + 1), -np.inf)
    R[0, 0] = 0

    # Nothing to align: avoid the 0-length path division and shape errors.
    if m == 0 or n == 0:
        return R, 0.0, []

    z = np.matmul(nx, ny.T)

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            R[i, j] = max(R[i - 1, j - 1], R[i - 1, j], R[i, j - 1]) + z[i - 1, j - 1]

    # Backtracking: ties are resolved diagonal first, then up, then left.
    path = []
    i, j = m, n
    while i >= 1 and j >= 1:
        path.append((i, j))
        diag, up, left = R[i - 1, j - 1], R[i - 1, j], R[i, j - 1]
        best = max(diag, up, left)

        if best == diag:
            i, j = i - 1, j - 1
        elif best == up:
            i -= 1
        elif best == left:
            j -= 1
        else:
            raise ValueError(
                "cannot backtrack: accumulated score table contains NaN; "
                "check the input embeddings"
            )

    return R, R[m, n] / len(path), path

In [5]:
# Ids used for the comparison: one query and two candidates.
qvid = "BJGywz0wWKg"
vid1 = "ICN26rD0i6Q"
vid2 = "pIvkh4QARX4"

# Look up the precomputed embedding sequence for each id.
query_video, video1, video2 = (vid2seqembs[v] for v in (qvid, vid1, vid2))

R1, rel1, path1 = compute_dtw_score(query_video, video1, 0, 0)
R2, rel2, path2 = compute_dtw_score(query_video, video2, 0, 0)

print(f"rel1: {rel1} rel2: {rel2}")

# Save one image per candidate: the accumulated-score table with the
# backtracked path drawn in white (x = column index, y = row index).
for acc, warp, out_png in ((R1, path1, "rel1.png"), (R2, path2, "rel2.png")):
    plt.matshow(acc)
    cols = [c for _, c in warp]
    rows = [r for r, _ in warp]
    plt.plot(cols, rows, "w")
    plt.savefig(out_png)
    plt.close()

rel1: 0.7128006815910339 rel2: 1.0000000541860408
