In [2]:
import torch
from s3dg import S3D

# Build the S3D encoder and load the HowTo100M checkpoint.
# NOTE(review): the args are presumably the word-dictionary file and the embedding
# dimension (512) -- confirm against s3dg.S3D.
s3d_model = S3D('s3d_dict.npy', 512)
# map_location='cpu' lets the checkpoint load on machines without CUDA even if its
# tensors were saved from a GPU; every input built in this notebook is a CPU tensor.
s3d_model.load_state_dict(torch.load('s3d_howto100m.pth', map_location='cpu'))

# Put the model in evaluation mode; it is only used for inference here.
s3d_model = s3d_model.eval()

In [3]:
import cv2

def video_to_tensor(video_path, size=(256, 256)):
    """Decode a video file into a float tensor batch for the S3D video encoder.

    Every frame is resized to ``size``, converted from OpenCV's BGR channel order
    to RGB, and scaled from [0, 255] to [0.0, 1.0].

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.
        size: ``(width, height)`` passed to ``cv2.resize`` (OpenCV's ``dsize``
            order). Defaults to 256x256. The resize does not preserve aspect ratio.

    Returns:
        A float32 tensor of shape ``(1, 3, T, H, W)``, where ``T`` is the number
        of decoded frames. The shape is also printed for quick inspection.

    Raises:
        RuntimeError: if the video cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"could not open video: {video_path}")

        frames = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frame = cv2.resize(frame, size)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # (H, W, C) uint8 -> (C, H, W) float32 in [0, 1]
            frames.append(torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0)
    finally:
        # Release the decoder even if decoding/resizing raises mid-stream.
        cap.release()

    if not frames:
        # torch.stack([]) would otherwise fail with an unhelpful message.
        raise RuntimeError(f"no frames decoded from video: {video_path}")

    # (T, C, H, W) -> (C, T, H, W) -> (1, C, T, H, W)
    video = torch.stack(frames).transpose(1, 0).unsqueeze(0)
    print(video.shape)
    return video


In [4]:
import os
import re

# Absolute, machine-specific location of the S3D assets; change this for your checkout.
S3D_DIR = "C:/Users/lee/Desktop/ml/flybyml/module/s3d"
demodir = f"{S3D_DIR}/demonstration"
failpath = f"{S3D_DIR}/failure/level_off_4s.mp4"

# Demo clips are named level_off_4s_<n>.mp4. Match only those names and order them
# by n, so unrelated files in the directory (e.g. desktop.ini) are ignored instead
# of shifting the index or pointing at non-existent files.
DEMO_PATTERN = re.compile(r"level_off_4s_(\d+)\.mp4")
demo_matches = sorted(
    (m for m in map(DEMO_PATTERN.fullmatch, os.listdir(demodir)) if m),
    key=lambda m: int(m.group(1)),
)
assert demo_matches, f"no level_off_4s_<n>.mp4 demos found in {demodir}"

demo_embeddings = []
with torch.no_grad():  # inference only: don't build autograd graphs
    for match in demo_matches:
        demo = video_to_tensor(os.path.join(demodir, match.group(0)))
        demo_embeddings.append(s3d_model(demo)['video_embedding'])

    failure_embedding = s3d_model(video_to_tensor(failpath))['video_embedding']

torch.Size([1, 3, 32, 256, 256])
torch.Size([1, 3, 32, 256, 256])
torch.Size([1, 3, 32, 256, 256])
torch.Size([1, 3, 32, 256, 256])


In [5]:
text = ['cockpit leveling out the plane']
with torch.no_grad():  # inference only
    text_embedding = s3d_model.text_module(text)['text_embedding']

# Dot-product score between each caption and every demonstration clip.
# The recorded scores exceed 1, so the embeddings are not unit-normalised; compare
# these values only with other text-video scores from this model.
for i, t in enumerate(text):
    print(f'"{t}" similarity with level-off demo')
    # Iterate over all loaded demos rather than a hard-coded count.
    for demo_embedding in demo_embeddings:
        score = torch.matmul(text_embedding[i], demo_embedding.t())
        print(score.item())
    print("-----------")

"cockpit leveling out the plane" simularity with level-off demo
4.106225490570068
5.893315315246582
6.238439083099365
-----------


In [6]:
# Same caption(s) scored against the failed level-off attempt, for comparison
# with the demonstration scores.
for i, t in enumerate(text):
    print(f'"{t}" similarity with level-off failure video')
    score = torch.matmul(text_embedding[i], failure_embedding.t())
    print(score.item())
    print("-----------")


"cockpit leveling out the plane" simularity with level-off failure video
2.7186119556427
-----------


In [77]:
# Raw video-to-video dot product between the first two demonstration clips.
# NOTE(review): unnormalised embeddings make this score scale-dependent; cosine
# similarity may be easier to compare across pairs.
btw_demo = demo_embeddings[0] @ demo_embeddings[1].t()
print(btw_demo.item())

0.14361393451690674


In [75]:
# Raw video-to-video dot product between the first demonstration and the failed attempt.
with_failure = demo_embeddings[0] @ failure_embedding.t()
print(with_failure.item())

0.17633603513240814
