In [1]:
%cd /home/donny/video_classification/

/home/donny/video_classification


In [2]:
import os
import logging
import json
import torch
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
import cv2
import tqdm
from sklearn.metrics import confusion_matrix, classification_report, mean_absolute_error

from action_recognition.experiment.config import WandbConfig
from action_recognition.evaluate.evaluate_video import start_inferencing, fn2outfn, load_fine_annot
from action_recognition.datasets.video_annotation_dataset import read_annotations

logging.basicConfig(level=logging.INFO)

In [3]:
# Model run columns, in the order each video_paths entry lists its run ids.
# Later cells select a model by its position in this tuple, so keep them in sync.
MODEL_COLUMNS = ('I3D', 'balancevid', 'mixclip', 'mixclip_balancevid', '10vid', 'logweight_balancevid')

# Each entry: (video path, one wandb run id per MODEL_COLUMNS entry). Uncomment rows to include more videos.
video_paths = [
    ('data/mouse_video/20190814/V_20190814_134337_OC0_c0.mp4', '2rde0ddy', 'a3gryeq6', '3t739td2', '114nzofm', '25v2nlrs', '3crs2xcr'),  # 0
    # ('data/mouse_video/20190906/V_20190906_123023_OC0_c0.mp4', 'l2mzft8q', '6i6zza7c', '2p49jz7y', '24r2keuu', '18yttzbt', '21y4c15h'),  # 1
    # ('data/mouse_video/20190906/V_20190906_123023_OC0_c1.mp4', '1ist6nhz', '2mm706ka', '1p7q2enf', 'pbp3db4b', '3lrrwzr4', '2oumot5a'),  # 2
    # ('data/mouse_video/20190917/V_20190917_124903_OC0_c0.mp4', 'rw815hps', '138hzuu4', '1a90abbt', '2hrki5cj', '3oas0xfo', '3n3db90n'),  # 3
    # ('data/mouse_video/20190917/V_20190917_124903_OC0_c2.mp4', 'ovoyjcx6', '2jsuyqrf', '21n0m1c2', '1zv4m9y6', '20udas4r', '1jinxthk'),  # 4
    # ('data/mouse_video/20190917/V_20190917_124903_OC0_c3.mp4', 'pog9nxuz', '6pd6lugv', 'atoycpqm', '2pdzvfpb', '2lmxand7', '24j30x41'),  # 5
    # ('data/mouse_video/20190919/V_20190919_134212_OC0_c0.mp4', 'v502fsm2', '25303eo8', '15poo2gb', '39xx0mn0', '2fzaldtp', '3vluaurl'),  # 6
    # ('data/mouse_video/20190919/V_20190919_134212_OC0_c1.mp4', '2mnxfyf9', 'zcqqgkat', '3fn7a708', '1vse5sl5', '3d9g5221', '1l3mel8c'),  # 7
    # ('data/mouse_video/20190919/V_20190919_134212_OC0_c2.mp4', '37iiuecu', '2jzmcfoh', '3aaz2qhl', '25yd68le', 'xm3vgehq', '15wek0v2'),  # 8
    ('data/mouse_video/20190919/V_20190919_134212_OC0_c3.mp4', '104u3l9o', '3edvhvwf', '1g5728fb', '2yzzmjyj', '16rxlt0t', '1upg15gr'),  # 9
]

# Positional model selection downstream is only meaningful if every row has one id per column.
for path, *run_ids in video_paths:
    assert len(run_ids) == len(MODEL_COLUMNS), (
        f"{path}: expected {len(MODEL_COLUMNS)} run ids, got {len(run_ids)}"
    )

# fns[i] = [video path, prediction-output file for each MODEL_COLUMNS entry (same order)].
fns = [
    [path] + [fn2outfn(path, f"run_{run_id}_model") for run_id in run_ids]
    for path, *run_ids in video_paths
]


In [6]:
# For every video in `fns`, write a copy into the current working directory (same
# basename) with the selected model's smoothed prediction and the fine-grained
# ground truth drawn as 'V' (positive) / 'X' (negative) on each frame.
# NOTE(review): two inputs sharing a basename would overwrite each other's output.
MODEL_COLUMN = 4    # position in each row's run-id list ('10vid' per the column comment on video_paths)
SMOOTH_WINDOW = 5   # causal moving-average window, in prediction steps
THRESHOLD = 0.8     # smoothed-probability cutoff for a positive prediction
LEAD_PAD = 15       # zero rows prepended before the predictions — TODO confirm (presumably clip length - 1)
OVERLAY_STYLE = (cv2.FONT_HERSHEY_PLAIN, 2, (0, 255, 255), 1, cv2.LINE_AA)  # font, scale, BGR color, thickness, line type


def _smoothed_prediction(outfn):
    """Load saved logits from `outfn` and return a [LEAD_PAD + L, 1] float 0/1 tensor for class 0.

    Softmax over the first two logits, causal moving average over SMOOTH_WINDOW
    steps (left zero-padding, so the first steps are pulled toward 0), threshold
    at THRESHOLD, keep column 0, then prepend LEAD_PAD zero rows.
    """
    # NOTE(review): torch.load unpickles the file — only use on trusted outputs.
    logits = torch.stack(torch.load(outfn))[:, :2]
    probs = torch.nn.functional.softmax(logits, dim=1)  # [L, 2]
    padded = torch.nn.functional.pad(probs.T.unsqueeze(1), (SMOOTH_WINDOW - 1, 0), 'constant', 0)  # [2, 1, L + W - 1]
    smoothed = torch.nn.functional.avg_pool1d(padded, SMOOTH_WINDOW, stride=1).squeeze(1).T  # [L, 2]
    positive = (smoothed >= THRESHOLD).float()
    return torch.cat([torch.zeros((LEAD_PAD, 1)), positive[:, :1]]).float()


def _ground_truth_mask(fn, length):
    """Return a [length, 1] float mask set to 1 on mask[start:end] for each span from load_fine_annot(fn)."""
    mask = torch.zeros(length, 1)
    for start, end in load_fine_annot(fn):
        mask[start:end] = 1
    return mask


def _label(mask, idx):
    """'V' if mask[idx] == 1, 'X' otherwise, '?' when idx is past the end of `mask`."""
    # Reported frame counts can disagree with the prediction length; don't crash mid-render.
    if idx >= mask.shape[0]:
        return '?'
    return 'V' if mask[idx].item() == 1 else 'X'


for fn, *outfns in fns:
    pred = _smoothed_prediction(outfns[MODEL_COLUMN])
    print("\nlen:", pred.shape[0], "name:", fn)
    gt = _ground_truth_mask(fn, pred.shape[0])

    # cap / out / pbar stay module-level names so a separate cleanup cell can still reach them.
    cap = cv2.VideoCapture(fn)
    pbar = tqdm.tqdm(total=int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
    size = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter(os.path.basename(fn), int(cap.get(cv2.CAP_PROP_FOURCC)), cap.get(cv2.CAP_PROP_FPS), size)
    try:
        frame_idx = 0
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break
            cv2.putText(frame, f'model: {_label(pred, frame_idx)}', (10, 30), *OVERLAY_STYLE)
            cv2.putText(frame, f'truth: {_label(gt, frame_idx)}', (10, 60), *OVERLAY_STYLE)
            out.write(frame)
            pbar.update(1)
            frame_idx += 1
    finally:
        # Always release the capture/writer, even if rendering is interrupted or raises.
        pbar.close()
        cap.release()
        out.release()


len: 24566 name: data/mouse_video/20190814/V_20190814_134337_OC0_c0.mp4
100%|██████████| 24566/24566 [01:30<00:00, 270.49it/s]

len: 58961 name: data/mouse_video/20190919/V_20190919_134212_OC0_c3.mp4
100%|██████████| 58961/58961 [03:13<00:00, 304.03it/s]


In [5]:

# Manual cleanup for the rendering cell: close the progress bar and release the
# video capture/writer it left behind (e.g. after an interrupted run).
def release_video_handles(namespace):
    """Call `pbar.close()`, `cap.release()` and `out.release()` for whichever of those names exist in `namespace`.

    Names missing from `namespace` (or bound to None) are skipped, so running this
    on a fresh kernel does not raise NameError.
    """
    for name, closer in (('pbar', 'close'), ('cap', 'release'), ('out', 'release')):
        handle = namespace.get(name)
        if handle is not None:
            getattr(handle, closer)()


release_video_handles(globals())

