In [1]:
%cd /home/yokoyama/research
from types import SimpleNamespace
import sys
import os
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.manifold import TSNE
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

sys.path.append(".")
from modules.utils.video import Capture, Writer
from modules.pose import PoseDataHandler

/raid6/home/yokoyama/research


In [2]:
from submodules.i3d.pytorch_i3d import InceptionI3d
from torchvision.ops import roi_align

In [3]:
video_num = 1
cap = Capture(f"/raid6/home/yokoyama/datasets/dataset01/train/{video_num:02d}.mp4")
pose_data = PoseDataHandler.load(f"data/dataset01/train/{video_num:02d}", ["bbox"])

In [46]:
def calc_opticalflow(frames):
    prev_frame = frames[0]
    prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_RGB2GRAY)
    h, w = prev_frame.shape[:2]
    flows = [np.zeros((h, w, 2), dtype=np.float32)]
    for frame in tqdm(frames[1:], leave=False):
        next_gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        flow = cv2.calcOpticalFlowFarneback(prev_gray, next_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
        flows.append(flow)

    return np.array(flows)


def load_rgb_frames(cap, batch_num, start, length):
    batches_frame = []
    batches_flow = []
    for i in tqdm(range(batch_num)):
        frames_raw = []
        frames = []
        start = start + i
        cap.set_pos_frame_count(start)
        for j in tqdm(range(start, start + length), leave=False):
            img = cap.read()[1]
            frames_raw.append(img.copy())
            img = (img / 255.) * 2 - 1
            frames.append(img)

        batches_frame.append(frames)

        flows = calc_opticalflow(frames_raw)
        batches_flow.append(flows)
    return np.array(batches_frame, dtype=np.float32), np.array(batches_flow, dtype=np.float32)


batch_num = 1
start_frame_num = 1410
frame_length = 30
frames, flows = load_rgb_frames(cap, batch_num, start_frame_num, frame_length)

100%|██████████| 1/1 [00:11<00:00, 11.16s/it]


In [63]:
in_channels = 3
if in_channels == 3:
    tensor = torch.Tensor(frames)
elif in_channels == 2:
    tensor = torch.Tensor(flows)
tensor = torch.permute(tensor, (0, 4, 1, 2, 3))
tensor.shape

torch.Size([1, 3, 30, 940, 1280])

In [64]:
model_path = "submodules/i3d/models/rgb_imagenet.pt"
# model_path = "submodules/i3d/models/flow_charades.pt"
i3d = InceptionI3d(in_channels=in_channels)
# i3d.replace_logits(157)
i3d.load_state_dict(torch.load(model_path))

# i3d = InceptionI3d(in_channels=in_channels, final_endpoint="Mixed_3b")
# i3d.build()

<All keys matched successfully>

In [65]:
# feature = i3d.extract_features(tensor)
x = tensor
for end_point in i3d.VALID_ENDPOINTS:
    if end_point in i3d.end_points:
        x = i3d._modules[end_point](x)
    if end_point == "Mixed_3b":
        break
feature = x
feature.shape

torch.Size([1, 256, 15, 118, 160])

torch.Size([2, 256, 15, 118, 160]) 3b  
torch.Size([2, 480, 15, 118, 160]) 3c  
torch.Size([1, 480, 8, 59, 80]) 4a  
torch.Size([1, 512, 8, 59, 80]) 4b  
torch.Size([1, 832, 8, 59, 80]) 4f

In [66]:
bboxs_all = []
for i in range(frame_length):
    pose_data_frame = [
        data for data in pose_data
        if data["frame"] == start_frame_num + i
    ]

    bboxs = [np.array(data["bbox"]).reshape(2, 2) for data in pose_data_frame]

    bboxs = torch.Tensor(np.array(bboxs))
    bboxs /= torch.Tensor((w, h))
    bboxs *= torch.Tensor((fx, fy))

    bboxs_all.append(torch.Tensor(bboxs.reshape(-1, 4)))

feature_aligned = roi_align(feature[0], bboxs_all, 3, spatial_scale=fx / w, aligned=False)
feature_aligned = feature_aligned.detach().numpy()

pids = [data["id"] for data in pose_data_frame]
loss_dict = {}
for i in range(len(pids) - 1):
    for j in range(i + 1, len(pids)):
        f1 = feature_aligned[i]
        f2 = feature_aligned[j]
        loss = (np.square(f1 - f2)).mean()

        key = f"{pids[i]}-{pids[j]}"
        loss_dict[key] = loss

In [67]:
loss_dict.keys()

dict_keys(['2-10', '2-6', '2-11', '2-1', '2-4', '10-6', '10-11', '10-1', '10-4', '6-11', '6-1', '6-4', '11-1', '11-4', '1-4'])

In [68]:
for key, loss in loss_dict.items():
    print(key, loss)

2-10 0.37568292
2-6 0.28895658
2-11 0.25893262
2-1 0.31032327
2-4 0.31060413
10-6 0.09601549
10-11 0.09754126
10-1 0.0953755
10-4 0.09537507
6-11 0.0012808882
6-1 0.00089278305
6-4 0.000913571
11-1 0.0031828596
11-4 0.0032410931
1-4 1.4524527e-06
