In [1]:
#Part-1 Video/image feature extraction

In [2]:
# Modified to process a list of videos
"""Extract features for videos using pre-trained networks"""
from feature_extract.configs.custom_config import load_config
from slowfast.utils.misc import launch_job
from slowfast.utils.parser import parse_args
import numpy as np
import torch
import os
import time
from tqdm import tqdm
import av
from moviepy.video.io.VideoFileClip import VideoFileClip

import slowfast.utils.checkpoint as cu
import slowfast.utils.distributed as du
import slowfast.utils.logging as logging
import slowfast.utils.misc as misc

from feature_extract.models import build_model
from feature_extract.datasets.extract_dataset import VideoSet
import copy
# Module-level logger obtained from slowfast.utils.logging (imported above as `logging`).
logger = logging.get_logger(__name__)




In [4]:
import os
import random
from io import BytesIO
import torch
import numpy as np
import cv2
from slowfast.datasets.utils import pack_pathway_output
from feature_extract.configs.custom_config import load_config
from slowfast.utils.parser import parse_args
from matplotlib import pyplot as plt
from collections import deque

import argparse
import sys





In [5]:
#Feature extraction configs
def parse_args():
    """
    Build the PySlowFast-style command-line parser and parse ``sys.argv``.

    Recognised options:
        --shard_id (int, default 0): shard id of this machine, from 0 to
            num_shards - 1; use 0 on a single machine.
        --num_shards (int, default 1): number of shards used by the job.
        --init_method (str, default "tcp://localhost:9999"): how a
            multi-device job is initialised (TCP or shared file-system), see
            https://pytorch.org/docs/stable/distributed.html#tcp-initialization
        --cfg (one or more paths, stored as ``cfg_files``): config files.
        --opts (remainder): extra options that override the loaded config.
        --f (remainder): present so the notebook kernel's own arguments are
            tolerated (per its help text).

    The help text is printed when no command-line arguments were given.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    option_table = [
        ("--shard_id", {
            "help": "The shard id of current node, Starts from 0 to num_shards - 1",
            "default": 0,
            "type": int,
        }),
        ("--num_shards", {
            "help": "Number of shards using by the job",
            "default": 1,
            "type": int,
        }),
        ("--init_method", {
            "help": "Initialization method, includes TCP or shared file-system",
            "default": "tcp://localhost:9999",
            "type": str,
        }),
        ("--cfg", {
            "dest": "cfg_files",
            "help": "Path to the config files",
            "default": ["/data/disk/LUO/slowfast/feature_extract/configs/SLOWFAST_8x8_R50_1031.yaml"],
            "nargs": "+",
        }),
        ("--opts", {
            "help": "See slowfast/config/defaults.py for all options",
            "default": None,
            "nargs": argparse.REMAINDER,
        }),
        ("--f", {
            "help": "for jupyternotebook to run",
            "default": None,
            "nargs": argparse.REMAINDER,
        }),
    ]

    parser = argparse.ArgumentParser(
        description="Provide SlowFast video training and testing pipeline."
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    if len(sys.argv) == 1:
        parser.print_help()
    return parser.parse_args()

In [6]:
# Method for producing streamed (real-time) output features

In [10]:
#Feature extraction image process

def pre_process_frame(arr, mean=(0.45, 0.45, 0.45), std=(0.225, 0.225, 0.225)):
    """
    Normalize a stack of frames and reorder it to channel-first layout.

    Args:
        arr (ndarray): frames of shape T x H x W x C; values are presumably
            in [0, 255] (e.g. uint8 frames from OpenCV), since they are
            divided by 255.
        mean (sequence of float): per-channel mean subtracted after scaling
            to [0, 1]. The default matches the previously hard-coded
            DATA.MEAN value.
        std (sequence of float): per-channel std divided out after the mean
            subtraction. The default matches the previously hard-coded
            DATA.STD value.

    Returns:
        tensor: a normalized float tensor of shape C x T x H x W.

    Raises:
        ValueError: if ``arr`` is not 4-dimensional. Previously this case was
            only printed and a tensor with the wrong layout was returned.
    """
    if arr.ndim != 4:
        raise ValueError(
            "expected an array of frames shaped T x H x W x C, got shape {}".format(arr.shape)
        )
    tensor = torch.from_numpy(arr).float() / 255.0
    # Broadcasts over the trailing channel dimension.
    tensor = (tensor - torch.tensor(mean)) / torch.tensor(std)
    # T H W C -> C T H W.
    return tensor.permute(3, 0, 1, 2)

In [11]:
# Skeleton for running model(input)

In [12]:
def calculate_time_taken(start_time, end_time):
    """Split the interval ``end_time - start_time`` (in seconds) into whole units.

    Returns:
        tuple: ``(hours, minutes, seconds)``, each truncated to an int, where
        minutes and seconds are the remainders after the larger units.
    """
    elapsed = end_time - start_time
    whole_hours = int(elapsed / 3600)
    leftover_minutes = int(elapsed / 60) - whole_hours * 60
    return whole_hours, leftover_minutes, int(elapsed % 60)

#feature extraction slow fast inference
@torch.no_grad()
def perform_inference(inputs, model, cfg):
    """
    Run the pretrained video model on one batch of clips and return its features.

    Args:
        inputs (tensor or list of tensors): model input (for SlowFast, the list
            of pathway tensors). Each element is moved to GPU ``cfg.USED_GPU``;
            a list passed in is not modified (a new list is built).
        model (model): the pretrained video model. It is put in eval mode and
            called as ``model(inputs)``, which must return ``(preds, feat)``.
        cfg (CfgNode): configs; reads ``USED_GPU`` and ``NUM_GPUS``. Details
            can be found in slowfast/config/defaults.py

    Returns:
        ndarray: the model's ``feat`` output, copied to the CPU.
    """
    # Enable eval mode.
    model.eval()

    # Transfer the data to the current GPU device without touching the caller's list.
    if isinstance(inputs, list):
        inputs = [x.cuda(device=cfg.USED_GPU, non_blocking=True) for x in inputs]
    else:
        inputs = inputs.cuda(device=cfg.USED_GPU, non_blocking=True)

    # Perform the forward pass.
    preds, feat = model(inputs)
    # Gather predictions and features from all devices in multi-GPU runs.
    if cfg.NUM_GPUS > 1:
        preds, feat = du.all_gather([preds, feat])

    return feat.cpu().numpy()

In [13]:
#video feature extraction stream output
def test(model,cfg,inputs):
    """
    Extract features for one batch of clips with an already-loaded video model.

    Every call re-seeds NumPy and PyTorch from ``cfg.RNG_SEED`` and calls
    ``logging.setup_logging(cfg.OUTPUT_DIR)`` before running inference.

    Args:
        model (model): the video model. The caller is expected to have built
            it and loaded its checkpoint already; nothing is loaded here.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        inputs (tensor or list of tensors): model input, passed unchanged to
            ``perform_inference``.

    Returns:
        ndarray: the features returned by ``perform_inference``.
    """
    # Set random seed from configs.
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)

    # Setup logging format.
    logging.setup_logging(cfg.OUTPUT_DIR)

    return perform_inference(inputs, model, cfg)




In [14]:
# tridet evaluation configs

In [15]:
# python imports
import argparse
import os
import glob
import time
from pprint import pprint

# torch imports
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data

# our code
from libs.core import load_tridet_config
from libs.datasets import make_dataset, make_data_loader
from libs.modeling import make_meta_arch
from libs.utils import valid_one_epoch, ANETdetection, fix_random_seed


In [16]:
#create test dataset
import os
import json
import h5py
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset
from torch.nn import functional as F

from libs.datasets.datasets import register_dataset
from libs.datasets.data_utils import truncate_feats
from libs.utils import remove_duplicate_annotations


In [17]:
def _load_json(label_dict=None,split=['testing'],num_classes=19,json_file="/data/disk/LUO/test_only/r_tridet/TriDet/data/cataract/data_1102.json",default_fps=30):

        #create a csv
        #frame_start,frame_end,label,training_or_testing
        # load database and select the subset
        with open(json_file) as f:
            json_data = json.load(f)
        # if label_dict is not available, matching label(str) to label id (int)
        if label_dict is None:
            label_dict = {
            #define surgical phase name to phase label id
                "phase_{}".format(i+1):i for i in range(num_classes)
            }
            
        if split[0] == "training":
            split_name = "train"
        else:
            split_name = "test"
        # fill in the db (immutable afterwards)
        dict_db = tuple()
        for key, value in json_data.items():
            # key is the video id, v is the segement information
            # get fps if available

            if split_name not in value["annotation"][0]["subset"]:
                continue
            if default_fps is not None:
                fps = default_fps
            else:
                #fps=30
                fps = 1

            duration = []
            num_phase = len(value["annotation"])
            segments = np.zeros([num_phase, 2], dtype=np.float32)
            labels = np.zeros([num_phase, ], dtype=np.int64)
            for idx,phase in enumerate(value["annotation"]):
                duration.append(phase["time_till_now"])
                segments[idx][0] = phase["start"]
                segments[idx][1] = phase["end"]
                if num_classes == 1:
                    labels[idx] = 0
                else:
                    labels[idx] = phase["label"]
            last_time = value["last_time"]
            dict_db += ({'id': key.split(".")[0],
                         'fps': fps,
                         'duration': last_time,
                         'segments': segments,
                         'labels': labels
                         },)

        return dict_db, label_dict

In [18]:
def getitem(data_list,features, idx,num_frames,max_seq_len=1024,feat_stride=1,downsample_rate=1,force_upsampling=False,mirror=True,include_gt=False):
    """
    Package the features of one video into the data dict consumed by TriDet.

    Args:
        data_list (sequence of dict): video entries (e.g. from ``_load_json``);
            entry ``idx`` supplies ``id`` and, when requested, ``segments`` and
            ``labels``.
        features (ndarray or tensor): features shaped T x C (time first).
        idx (int): index of the video in ``data_list``.
        num_frames (int): stored as ``duration`` in the result.
        max_seq_len (int): unused; kept for interface compatibility.
        feat_stride (int): stride between consecutive features.
        downsample_rate (int): when > 1 (and ``feat_stride > 0`` and not
            ``force_upsampling``), keep every ``downsample_rate``-th feature
            and scale ``feat_stride`` accordingly.
        force_upsampling (bool): when True, skip downsampling.
        mirror (bool): unused; kept for interface compatibility.
        include_gt (bool): when True and the entry's ``segments`` is not None,
            return its ``segments``/``labels`` as tensors, used as stored
            (no conversion to feature grid units). By default both are None,
            which is what inference has always received.

    Returns:
        dict: ``video_id``, ``feats`` (C x T tensor), ``segments``, ``labels``,
        ``fps`` (1), ``duration`` (``num_frames``), ``feat_stride`` and
        ``feat_num_frames`` (32).
    """
    video_item = data_list[idx]
    feats = features

    # Only down-sampling is applied here.
    if feat_stride > 0 and not force_upsampling and downsample_rate > 1:
        feats = feats[::downsample_rate, :]
        feat_stride = feat_stride * downsample_rate

    # T x C -> C x T
    if isinstance(feats, torch.Tensor):
        feats = feats.transpose(0, 1)
    else:
        feats = torch.from_numpy(np.ascontiguousarray(feats.transpose()))

    # Ground truth is not needed for inference. Initialising here also avoids
    # an UnboundLocalError for entries whose segments are None.
    segments, labels = None, None
    if include_gt and video_item['segments'] is not None:
        segments = torch.from_numpy(video_item['segments'])
        labels = torch.from_numpy(video_item['labels'])

    return {'video_id': video_item['id'],
            'feats': feats,  # C x T
            'segments': segments,  # N x 2
            'labels': labels,  # N
            'fps': 1,
            'duration': num_frames,
            'feat_stride': feat_stride,
            'feat_num_frames': 32}


In [19]:
class AverageMeter(object):
    """Track the most recent value and the running (weighted) average.

    Used to compute dataset statistics from mini-batches: each update
    contributes ``val`` with weight ``n``.
    """

    def __init__(self):
        self.initialized = False
        self.val = self.avg = self.sum = None
        self.count = 0.0

    def initialize(self, val, n):
        """Start the statistics from a first value ``val`` with weight ``n``."""
        self.val, self.avg = val, val
        self.sum, self.count = val * n, n
        self.initialized = True

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (the first call initializes)."""
        handler = self.add if self.initialized else self.initialize
        handler(val, n)

    def add(self, val, n):
        """Accumulate ``val`` with weight ``n`` and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [20]:
def valid_one_epoch(
        dict_db,
        model,
        curr_epoch=0,
        ext_score_file=None,
        evaluator=None,
        output_file=None,
        tb_writer=None,
        print_freq=20
):
    """
    Run the detector once in eval mode on a batch of videos and collect detections.

    Args:
        dict_db (list of dict): per-video data dicts, passed to ``model`` as-is.
        model: called as ``model(dict_db)`` (without gradients); must return
            one dict per video with ``video_id``, ``segments`` (N x 2),
            ``labels`` (N) and ``scores`` (N) tensors.
        curr_epoch, ext_score_file, evaluator, output_file, tb_writer,
        print_freq: unused here; kept for signature compatibility.

    Returns:
        dict: ``video-id`` (list, one entry per detection) and NumPy arrays
        ``t-start``, ``t-end``, ``label`` and ``score``. When the model
        detects nothing, the arrays are empty instead of raising.
    """
    # switch to evaluate mode
    model.eval()
    results = {
        'video-id': [],
        't-start': [],
        't-end': [],
        'label': [],
        'score': []
    }

    with torch.no_grad():
        output = model(dict_db)

        # unpack the results into ANet format
        for vid_out in output:
            segs = vid_out['segments']
            num_det = segs.shape[0]
            if num_det == 0:
                continue
            results['video-id'].extend([vid_out['video_id']] * num_det)
            results['t-start'].append(segs[:, 0])
            results['t-end'].append(segs[:, 1])
            results['label'].append(vid_out['labels'])
            results['score'].append(vid_out['scores'])

    def _cat_to_numpy(chunks, empty_dtype):
        # torch.cat raises on an empty list, so return an empty array instead.
        if not chunks:
            return np.zeros(0, dtype=empty_dtype)
        return torch.cat(chunks).cpu().numpy()

    results['t-start'] = _cat_to_numpy(results['t-start'], np.float32)
    results['t-end'] = _cat_to_numpy(results['t-end'], np.float32)
    results['label'] = _cat_to_numpy(results['label'], np.int64)
    results['score'] = _cat_to_numpy(results['score'], np.float32)

    return results


In [21]:
import os
import json
import pandas as pd
import numpy as np
from joblib import Parallel, delayed
from typing import List
from typing import Tuple
from typing import Dict
import datetime


In [22]:
def load_gt_seg_from_json(json_file, split=None, label='label', label_offset=0):
    """
    Flatten ground-truth phase segments from the annotation JSON into a DataFrame.

    Args:
        json_file (str): JSON mapping video file name -> dict whose
            ``annotation`` list holds dicts with ``subset``, ``start``,
            ``end`` and a label field.
        split (str, optional): ``"training"`` keeps annotations whose
            ``subset`` contains ``"train"``; anything else (including None)
            keeps those containing ``"test"``.
        label (str): key of the label field inside each annotation
            (default ``'label'``).
        label_offset (int): unused; kept for interface compatibility.

    Returns:
        pd.DataFrame: columns ``video-id`` (file name without extension),
        ``t-start``, ``t-end`` and ``label``, one row per kept annotation.
    """
    with open(json_file, "r", encoding="utf8") as f:
        json_db = json.load(f)

    split_name = "train" if split == "training" else "test"

    vids, starts, stops, labels = [], [], [], []
    for video_name, video_info in json_db.items():
        vid = video_name.split(".")[0]
        # Each annotation carries its own subset tag.
        for ants in video_info["annotation"]:
            if split_name not in ants["subset"]:
                continue
            vids.append(vid)
            starts.append(ants["start"])
            stops.append(ants["end"])
            labels.append(ants[label])

    return pd.DataFrame({
        'video-id': vids,
        't-start': starts,
        't-end': stops,
        'label': labels
    })

In [23]:
def to_df(preds):
    """
    Convert detection results to a pandas DataFrame.

    No evaluation or file loading happens here.

    Args:
        preds: either a dict with keys ``video-id``, ``t-start``, ``t-end``,
            ``label`` and ``score`` (as returned by ``valid_one_epoch``; the
            last four must support ``.tolist()``, e.g. NumPy arrays or torch
            tensors), or anything else, which is returned unchanged (for
            example an existing DataFrame).

    Returns:
        pd.DataFrame with columns ``video-id``, ``t-start``, ``t-end``,
        ``label`` and ``score`` when ``preds`` is a dict; otherwise ``preds``.
    """
    if not isinstance(preds, dict):
        return preds
    return pd.DataFrame({
        'video-id': preds['video-id'],
        't-start': preds['t-start'].tolist(),
        't-end': preds['t-end'].tolist(),
        'label': preds['label'].tolist(),
        'score': preds['score'].tolist()
    })


In [24]:
def calculate_acc(pred_df,gt_df,json_file="/data/disk/LUO/test_only/r_tridet/TriDet/data/cataract/data_1102.json"):
    """
    Expand predicted segments into a per-second label timeline for every test video.

    Despite its name this computes no metric. It returns the predicted
    per-second labels, which are compared against ground truth elsewhere
    (e.g. with sklearn's ``accuracy_score``).

    Args:
        pred_df (pd.DataFrame): detections with columns ``video-id`` (file
            name without extension), ``t-start``, ``t-end`` (seconds) and
            ``label``.
        gt_df (pd.DataFrame): unused (the ground-truth timeline used to be
            built here and then discarded); kept so existing calls still work.
        json_file (str): annotation JSON mapping video file name -> dict with a
            ``last_time`` entry (video length in seconds). Keys containing
            ``"train"`` are skipped.

    Returns:
        pd.DataFrame: columns ``video_id`` (full JSON key, e.g.
        ``"test01.mp4"``) and ``pred_labels``, with ``round(last_time)`` rows
        per video. Seconds covered by a segment get ``label + 1``; later
        segments overwrite earlier ones; uncovered seconds stay 0.
    """
    with open(json_file, "r") as f:
        gt_json = json.load(f)

    video_ids, pred_labels = [], []
    for key, value in gt_json.items():
        if "train" in key:
            continue
        num_seconds = int(np.round(value["last_time"]))
        timeline = [0] * num_seconds
        subset = pred_df[pred_df["video-id"] == key.split(".")[0]]
        for t_start, t_end, label in zip(subset["t-start"], subset["t-end"], subset["label"]):
            # Clamp to the video. Predicted starts may be negative, which used
            # to insert spurious rows at index -1, and ends may run past the
            # last second.
            first = max(int(np.round(t_start)), 0)
            last = min(int(np.round(t_end)), num_seconds - 1)
            for i in range(first, last + 1):
                timeline[i] = label + 1
        video_ids.extend([key] * num_seconds)
        pred_labels.extend(timeline)

    # Build the frame once instead of concatenating inside the loop.
    return pd.DataFrame({"video_id": video_ids, "pred_labels": pred_labels})

In [25]:
def get_middle_label(pred_df,mid_index,threshold=0.22):
    """
    Pick the label of the first confident detection that overlaps a time window.

    Args:
        pred_df (pd.DataFrame): detections with ``t-start``, ``t-end``,
            ``label`` and ``score`` columns.
        mid_index (sequence): ``[start, end]`` of the query window; a
            detection overlaps when ``t-start <= end`` and ``t-end >= start``.
        threshold (float): a detection is confident when ``score > threshold``.

    Returns:
        tuple: ``(label, score)``.
        - If some detection is confident: the first confident one (row order)
          that overlaps the window, or ``(0, threshold)`` if none overlaps.
        - If no detection is confident: label 0 with the score of the first
          overlapping detection, or ``(0, threshold)`` if nothing overlaps.
    """
    win_start, win_end = mid_index[0], mid_index[1]
    confident = pred_df[pred_df["score"] > threshold]

    def _overlapping(frame):
        hit = (frame["t-start"] <= win_end) & (frame["t-end"] >= win_start)
        return frame[hit]

    if len(confident) == 0:
        # Nothing is confident: report "no phase" (0) with the score of the
        # first overlapping low-confidence detection, if any.
        weak_hits = _overlapping(pred_df)
        if len(weak_hits) > 0:
            return 0, weak_hits.iloc[0]["score"]
        return 0, threshold

    strong_hits = _overlapping(confident)
    if len(strong_hits) > 0:
        first_hit = strong_hits.iloc[0]
        return first_hit["label"], first_hit["score"]
    return 0, threshold


In [26]:
"""def get_middle_label(pred_df,mid_index,threshold=0.22):
    df = pred_df[pred_df["score"]>threshold]
    label = 0
    s = mid_index[0]
    e = mid_index[1]
    #filter df
    for i in range(len(df)):
        if mid_index>=df.iloc[i]["t-start"] and mid_index <= df.iloc[i]["t-end"]:
            label = df.iloc[i]["label"]
            break
    return label"""

'def get_middle_label(pred_df,mid_index,threshold=0.22):\n    df = pred_df[pred_df["score"]>threshold]\n    label = 0\n    s = mid_index[0]\n    e = mid_index[1]\n    #filter df\n    for i in range(len(df)):\n        if mid_index>=df.iloc[i]["t-start"] and mid_index <= df.iloc[i]["t-end"]:\n            label = df.iloc[i]["label"]\n            break\n    return label'

In [36]:
def mirror_feature(inputs,fill_n_times=3,max_len=1024):
    """
    Append an edge-padded, time-reversed copy of the features, in place.

    ``feats`` (C x T) becomes
    ``[feats, last column repeated fill_n_times, flip(feats)]``,
    e.g. (2304, 100) -> (2304, 200 + fill_n_times). If the result has more
    than ``max_len`` columns, only the centred ``max_len`` columns are kept.

    Args:
        inputs (dict): data dict with ``feats`` (C x T tensor) and
            ``duration``. It is modified in place and also returned.
        fill_n_times (int): number of copies of the last column inserted
            between the original and the reversed features; values <= 0
            insert none (this case previously raised a NameError).
        max_len (int): maximum number of columns kept (default 1024).

    Returns:
        dict: ``inputs`` with ``feats`` replaced and ``duration`` set to
        ``2 * duration + max(fill_n_times, 0)``, or to ``max_len`` when
        cropped.
    """
    feat = inputs["feats"]
    num_fill = max(fill_n_times, 0)

    pieces = [feat]
    if num_fill > 0:
        # Repeat the last time step to separate the two halves.
        pieces.append(feat[:, -1:].repeat(1, num_fill))
    pieces.append(torch.flip(feat, dims=[1]))
    feat_mirrored = torch.cat(pieces, dim=1)

    if feat_mirrored.size(1) > max_len:
        # Keep the centred window around the join.
        start_index = feat_mirrored.size(1) // 2 - max_len // 2
        inputs["feats"] = feat_mirrored[:, start_index:start_index + max_len]
        inputs['duration'] = max_len
        return inputs

    inputs["feats"] = feat_mirrored
    inputs['duration'] = inputs['duration'] * 2 + num_fill
    return inputs



In [38]:
def fill_feature(inputs,fill_n_times=3,max_len=1024):
    """
    Return a copy of the data dict with the last feature repeated at the end.

    ``feats`` (C x T) becomes ``[feats, last column repeated fill_n_times]``.
    If that has more than ``max_len`` columns, the centred ``max_len`` columns
    are kept.

    Args:
        inputs (dict): data dict with ``feats`` (C x T tensor) and
            ``duration``; it is deep-copied and left unmodified.
        fill_n_times (int): number of copies of the last column to append;
            when <= 0, an unmodified copy is returned (no cropping).
        max_len (int): maximum number of columns kept (default 1024).

    Returns:
        dict: the copy, with ``duration`` increased by ``fill_n_times``, or
        set to ``max_len`` when cropped.
    """
    input_copy = copy.deepcopy(inputs)
    if fill_n_times <= 0:
        return input_copy

    feat = input_copy["feats"]
    feat_fill = feat[:, -1:].repeat(1, fill_n_times)
    feat_filled = torch.cat((feat, feat_fill), dim=1)

    if feat_filled.size(1) > max_len:
        start_index = feat_filled.size(1) // 2 - max_len // 2
        input_copy["feats"] = feat_filled[:, start_index:start_index + max_len]
        input_copy['duration'] = max_len
        return input_copy

    input_copy["feats"] = feat_filled
    input_copy['duration'] = input_copy['duration'] + fill_n_times
    return input_copy

In [None]:
def smooth(length=10):
    """
    Placeholder for an unfinished smoothing helper.

    The original cell contained only ``def smooth(length=10,)`` with no body,
    which is a SyntaxError that breaks "Run All" and script export. The
    intended behavior (presumably smoothing a sequence of per-second labels
    over a window of ``length``) is not recorded anywhere. TODO: implement.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("smooth() has not been implemented yet")

In [30]:
#Feature extraction configs
# Configuration for the streaming run on a single test video.
# Parse CLI arguments with the parse_args defined above and load the SlowFast config.
args = parse_args()
cfg = load_config(args)
cfg.USED_GPU = 7 #GPU used by perform_inference; must match the device the extractor is built on
vid_path = "/data/disk/cataracts_dataset/test_video/test01.mp4"

#tridet_cfg = load_tridet_config("/data/disk/LUO/test_only/TriDet/configs/cataract_slowfast.yaml")

#ckpt_file = '/data/disk/LUO/epoch_110_best.pth.tar'

# TriDet detector config and checkpoint (the "reverse" model, epoch 100).
tridet_cfg = load_tridet_config("/data/disk/LUO/test_only/TriDet/configs/cataract_slowfast_reverse.yaml")
ckpt_file = '/data/disk/LUO/test_only/TriDet/ckpt/cataract_slowfast_reverse_reverse/epoch_100.pth.tar'

# NOTE(review): topk / print_freq / save_only are not used in the visible cells.
topk=-1
print_freq = 10
save_only = False

# fix the random seeds (this will fix everything)
_ = fix_random_seed(0, include_cuda=True)


In [50]:
# Streaming evaluation of a single video. Frames are read one at a time into a
# rolling 32-frame buffer. Once the buffer is full, a SlowFast feature is
# extracted every 30 frames, and TriDet is re-run on all features gathered so
# far in three variants: plain, end-padded ("filled") and mirrored.
# perform_inference moves inputs to GPU cfg.USED_GPU, so build the extractor
# on that same GPU.
device = "cuda:{}".format(cfg.USED_GPU)
cap = cv2.VideoCapture(vid_path)
sample_height = 256
sample_width = 256
count = 0
seq_length = 32
queue = deque(maxlen=seq_length)
ret = True
frame_wise_feature = None
extract_model = build_model(cfg).to(device)
print("loading feature extraction model")
extract_checkpoint = torch.load("/data/disk/LUO/slowfast/feature_extract/checkpoints/checkpoint_epoch_00130.pyth",map_location=device)
extract_model.load_state_dict(extract_checkpoint['model_state'])
del extract_checkpoint
# "test01" -> 0. Assumes _load_json returns the test videos in test01, test02, ... order (TODO confirm).
idx = int(vid_path.split("/")[-1].split(".")[0].split("test")[-1])-1
dict_db,label_dict = _load_json() #dict_db contains segments and labels
current_db = dict_db[idx] #current test file json
count_2 = 0 #frames processed since the buffer became full
num_fill = 24
# load tridet model
model = make_meta_arch(tridet_cfg['model_name'], **tridet_cfg['model'])
# not ideal for multi GPU training, ok for now
model = nn.DataParallel(model, device_ids=tridet_cfg['devices'])
print("=> loading checkpoint '{}'".format(ckpt_file))
checkpoint = torch.load(
    ckpt_file,
    map_location=lambda storage, loc: storage.cuda(tridet_cfg['devices'][0])
)

# load ema model instead
print("Loading from EMA model ...")
model.load_state_dict(checkpoint['state_dict_ema'])
del checkpoint
pure_pred_list = []
fill_pred_list = []
pred_list = []
start = datetime.datetime.now()
try:
    while ret:
        ret, frame = cap.read()
        if not ret:
            print("finished")
            break
        # cv2.resize takes dsize as (width, height).
        frame = cv2.resize(frame, (sample_width, sample_height), interpolation=cv2.INTER_LINEAR)
        # NOTE(review): OpenCV decodes frames as BGR; confirm the extractor was trained with the same channel order.
        queue.append(frame)
        if count < seq_length:
            count += 1
            continue

        # One feature every 30 frames, i.e. once per second (assumes 30 fps video).
        if count_2 % 30 == 0:
            cur_second = count // 30
            print("current second: ", cur_second)

            # Pre-process the buffer only when a feature is actually extracted.
            frames = pre_process_frame(np.stack(queue, axis=0))
            # list[0] is the slow pathway (3, 8, 256, 256), list[1] the fast pathway (3, 32, 256, 256).
            frame_list = pack_pathway_output(cfg, frames)
            frame_list[0] = frame_list[0].unsqueeze(0)  # add a batch dimension
            frame_list[1] = frame_list[1].unsqueeze(0)
            # One feature per 32-frame window (noted as shape (1, 2304)); stack them along dim 0.
            cur_feature = test(extract_model, cfg, frame_list)
            if frame_wise_feature is None:
                frame_wise_feature = cur_feature
            else:
                frame_wise_feature = np.concatenate((frame_wise_feature, cur_feature), axis=0)

            # 1) Plain features.
            inputs = getitem(data_list=dict_db, features=frame_wise_feature, idx=idx, num_frames=len(frame_wise_feature))
            results = valid_one_epoch([inputs], model=model)
            df = to_df(results)
            pred_label, score = get_middle_label(df, [cur_second-2, cur_second], threshold=0.22)
            pure_pred_list.append(pred_label)
            print("no mirror predicted label: ",pred_label," score:",score)

            # 2) Features padded at the end with num_fill copies of the last one.
            filled_feature = fill_feature(inputs, num_fill)
            results = valid_one_epoch([filled_feature], model=model)
            df = to_df(results)
            pred_label, score = get_middle_label(df, [cur_second-2, cur_second+num_fill], threshold=0.22)
            fill_pred_list.append(pred_label)
            print("filled predicted label: ",pred_label," score:",score)

            # 3) Mirrored features.
            mirror_features = mirror_feature(inputs, num_fill)
            results = valid_one_epoch([mirror_features], model=model)
            df = to_df(results)
            if mirror_features["feats"].size(1) >= 1024:
                # Cropped to the centred 1024 columns: query around the centre.
                pred_label, score = get_middle_label(df, [512-num_fill//2, 512+num_fill//2], 0.22)
            else:
                pred_label, score = get_middle_label(df, [cur_second+num_fill-2, cur_second+num_fill], threshold=0.22)

            print("mirror predicted label: ",pred_label," score:",score)
            print()
            pred_list.append(pred_label)
        count += 1
        count_2 += 1
finally:
    # Release the video file even if inference raises.
    cap.release()
end = datetime.datetime.now()

print("total time: ", (end-start).seconds)



loading feature extraction model
=> loading checkpoint '/data/disk/LUO/test_only/TriDet/ckpt/cataract_slowfast_reverse_reverse/epoch_100.pth.tar'
Loading from EMA model ...
current second:  1
no mirror predicted label:  0  score: 0.22
filled predicted label:  0  score: 0.22
mirror predicted label:  0  score: 0.22

current second:  2
no mirror predicted label:  0  score: 0.22
filled predicted label:  0  score: 0.22
mirror predicted label:  0  score: 0.22

current second:  3
no mirror predicted label:  0  score: 0.22
filled predicted label:  0  score: 0.22
mirror predicted label:  0  score: 0.22

current second:  4
no mirror predicted label:  0  score: 0.22
filled predicted label:  0  score: 0.22
mirror predicted label:  0  score: 0.22

current second:  5
no mirror predicted label:  0  score: 0.22
filled predicted label:  0  score: 0.22
mirror predicted label:  0  score: 0.22

current second:  6
no mirror predicted label:  0  score: 0.22
filled predicted label:  0  score: 0.22
mirror pre

In [66]:
# Detections from the most recent TriDet run in the streaming loop (the
# mirrored-input results) with score >= 0.22. Note that get_middle_label uses a
# strict ">" threshold.
df[df["score"]>=0.22]

Unnamed: 0,video-id,t-start,t-end,label,score
0,test01,23.998512,139.150421,10,0.90111
1,test01,916.218018,1024.0,10,0.894918
2,test01,306.380035,397.245392,15,0.772213
3,test01,659.538025,741.631104,15,0.762264
4,test01,150.395798,155.258194,4,0.758918
5,test01,899.757202,904.679565,4,0.690672
6,test01,510.535278,539.998108,11,0.666442
7,test01,783.475525,786.212402,4,0.639494
8,test01,268.491913,271.027039,4,0.614927
9,test01,290.328003,302.472351,14,0.538024


In [51]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score

In [52]:
# Per-second ground-truth labels (columns include video_id and gt_labels).
# NOTE(review): presumably written by the now-disabled df_gt.to_csv code in calculate_acc — confirm.
gt = pd.read_csv("/data/disk/LUO/test_only/r_tridet/TriDet/df_gt.csv")

In [59]:
# Per-second accuracy of the plain (no fill / no mirror) predictions for test01.
# Predictions only start once the 32-frame buffer is full, so they are aligned
# with the *last* len(pure_pred_list) ground-truth seconds (assumption — verify).
# .iloc makes the positional slice explicit (plain integer slicing on an
# integer-indexed Series is ambiguous).
gt_test01 = gt.loc[gt["video_id"] == "test01.mp4", "gt_labels"]
accuracy_score(gt_test01.iloc[len(gt_test01) - len(pure_pred_list):], pure_pred_list)

0.5099601593625498

In [314]:
# Last len(pred_list) ground-truth seconds of test01, aligned with the streaming
# predictions. Copied so the prediction column added next does not write into
# a slice of `gt` (avoids SettingWithCopyWarning).
x = gt[gt["video_id"] == "test01.mp4"].iloc[-len(pred_list):].copy()

In [315]:
# Attach the mirrored-input predictions (one per sampled second) to the GT rows.
x["pred_labels"] = pred_list

In [316]:
# Save ground truth and streaming predictions side by side.
x.to_csv("result_real_time.csv")

In [249]:
#Feature extraction configs
# Configuration for evaluating a folder of test videos (GPU 4, epoch-80 checkpoint).
args = parse_args()
cfg = load_config(args)
cfg.USED_GPU = 4 #GPU used by perform_inference; the extractor must be built on cuda:4 too

#tridet_cfg = load_tridet_config("/data/disk/LUO/test_only/TriDet/configs/cataract_slowfast.yaml")
#ckpt_file = '/data/disk/LUO/epoch_110_best.pth.tar'

tridet_cfg = load_tridet_config("/data/disk/LUO/test_only/TriDet/configs/cataract_slowfast_reverse.yaml")
ckpt_file = '/data/disk/LUO/test_only/TriDet/ckpt/cataract_slowfast_reverse_reverse/epoch_080.pth.tar'

# NOTE(review): topk / print_freq / save_only are not used in the visible cells.
topk=-1
print_freq = 10
save_only = False

# fix the random seeds (this will fix everything)
_ = fix_random_seed(0, include_cuda=True)


In [250]:
# Load the SlowFast feature extractor and the TriDet detector (EMA weights)
# for the folder evaluation below.
# perform_inference sends inputs to GPU cfg.USED_GPU, so build the extractor on
# that same GPU rather than reusing the `device` left over from the streaming
# cell (which pointed at cuda:7).
device = "cuda:{}".format(cfg.USED_GPU)
extract_model = build_model(cfg).to(device)
extract_checkpoint = torch.load("/data/disk/LUO/slowfast/feature_extract/checkpoints/checkpoint_epoch_00130.pyth",map_location=device)
print("Loading from Feature Extraction Model ...")
extract_model.load_state_dict(extract_checkpoint['model_state'])
del extract_checkpoint
model = make_meta_arch(tridet_cfg['model_name'], **tridet_cfg['model'])
model = nn.DataParallel(model, device_ids=tridet_cfg['devices'])
print("=> loading checkpoint '{}'".format(ckpt_file))
checkpoint = torch.load(
    ckpt_file,
    map_location=lambda storage, loc: storage.cuda(tridet_cfg['devices'][0])
)
print("loading finish")
# load ema model instead
print("Loading from EMA model ...")
model.load_state_dict(checkpoint['state_dict_ema'])
del checkpoint
print("loading finish")
gt = pd.read_csv("/data/disk/LUO/test_only/r_tridet/TriDet/df_gt.csv")
dir_name = "/data/disk/LUO/cataract_test_video"

Loading from Feature Extraction Model ...
=> loading checkpoint '/data/disk/LUO/test_only/TriDet/ckpt/cataract_slowfast_reverse_reverse/epoch_080.pth.tar'
loading finish
Loading from EMA model ...
loading finish


In [369]:

def main(vid, log_folder, threshold=0.1, num_fill=32):
    """Simulate online (frame-by-frame) phase recognition on one test video.

    Frames are read one by one, resized to 256x256 and pushed into a 32-frame
    sliding window. After the first 32 frames, every 30th window (once per
    second, assuming 30 fps) is turned into a SlowFast feature with
    ``test(extract_model, cfg, ...)``. All features collected so far are passed
    through ``mirror_feature`` (with ``num_fill``) and the TriDet ``model``.
    ``get_middle_label`` then picks the label for the current step from a
    window of the detections.

    Args:
        vid: video file name inside ``dir_name``, e.g. ``"test01.mp4"``. The
            number after ``"test"``, minus one, selects the ``_load_json()``
            entry used to build the model inputs.
        log_folder: existing directory, relative to the working directory,
            that receives ``result_real_time_<vid>.csv``.
        threshold: score threshold forwarded to ``get_middle_label``. It is
            applied in both window branches; the fixed-middle branch used to
            hard-code 0.1 and silently ignore this argument.
        num_fill: value passed to ``mirror_feature``; it also positions and
            sizes the label window.

    Side effects:
        Prints the wall-clock time and the accuracy against the last
        ``len(predictions)`` rows of ``gt`` for this video. Writes a CSV of
        those rows with ``pred_labels`` / ``pred_scores`` columns added. If the
        video is too short to yield a prediction, nothing is written.

    Relies on notebook globals: dir_name, cfg, extract_model, model, gt and
    the helpers defined in earlier cells.
    """
    seq_length = 32            # frames per clip (the fast pathway gets all 32)
    sample_stride = 30         # one feature every 30 windows (~1 s at 30 fps)
    sample_size = (256, 256)   # cv2 dsize is (width, height)
    max_len = 1024             # from this feature length on, a fixed middle window is used

    print("Eval video: ", vid, " ...")
    vid_path = os.path.join(dir_name, vid)
    idx = int(os.path.basename(vid_path).split(".")[0].split("test")[-1]) - 1  # test01 -> 0
    dict_db, _ = _load_json()  # dict_db holds the segments/labels of every test video

    queue = deque(maxlen=seq_length)
    frame_wise_feature = None
    count = 0     # frame counter
    count_2 = 0   # windows seen after the queue filled up
    pred_list = []
    score_list = []

    cap = cv2.VideoCapture(vid_path)
    start = datetime.datetime.now()
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("finished")
                break
            frame = cv2.resize(frame, sample_size, interpolation=cv2.INTER_LINEAR)
            queue.append(frame)
            if count < seq_length:  # warm-up: only fill the queue
                count += 1
                continue

            if count_2 % sample_stride == 0:
                cur_second = count // sample_stride
                # Pre-processing is only needed for sampled windows, so it lives here
                # instead of running on every frame.
                frames = pre_process_frame(np.stack(queue, axis=0))
                # list[0]: slow pathway (3, 8, 256, 256); list[1]: fast pathway (3, 32, 256, 256)
                frame_list = pack_pathway_output(cfg, frames)
                frame_list[0] = frame_list[0].unsqueeze(0)  # add a batch dimension
                frame_list[1] = frame_list[1].unsqueeze(0)
                # one feature per 32-frame clip, (1, 2304) per the original note
                cur_feature = test(extract_model, cfg, frame_list)
                if frame_wise_feature is None:
                    frame_wise_feature = cur_feature
                else:
                    # stack features along dim 0 as the input sequence for the detector
                    frame_wise_feature = np.concatenate((frame_wise_feature, cur_feature), axis=0)

                inputs = getitem(data_list=dict_db, features=frame_wise_feature,
                                 idx=idx, num_frames=len(frame_wise_feature))
                results = valid_one_epoch([mirror_feature(inputs, num_fill)], model=model)
                df = to_df(results)
                if inputs["feats"].size(1) >= max_len:
                    window = [max_len // 2 - num_fill // 2, max_len // 2 + num_fill // 2]
                else:
                    window = [cur_second + num_fill - 2, cur_second + num_fill]
                pred_label, score = get_middle_label(df, window, threshold)
                pred_list.append(pred_label)
                score_list.append(score)
            count += 1
            count_2 += 1
    finally:
        cap.release()
    end = datetime.datetime.now()
    print("total time: ", (end - start).seconds)

    if not pred_list:
        # A [-0:] slice would select *all* gt rows and accuracy_score would fail.
        print("no prediction for vid: ", vid, " (video too short); nothing saved")
        return

    # Last len(pred_list) ground-truth rows of this video. The copy lets columns be
    # added without touching `gt`.
    result_df = gt[gt["video_id"] == vid].iloc[-len(pred_list):].copy()
    acc = accuracy_score(result_df["gt_labels"], pred_list)
    print("accuracy for vid: ", vid, " is : ", acc)

    result_df["pred_labels"] = pred_list
    result_df["pred_scores"] = score_list
    result_df.to_csv(os.path.join(".", log_folder, "result_real_time_" + vid + ".csv"))

In [370]:
# Sweep run: threshold=0.18, num_fill=24. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.18, num_fill=24)

Eval video:  test01.mp4  ...
finished
total time:  331
accuracy for vid:  test01.mp4  is :  0.6427622841965471
Eval video:  test02.mp4  ...
finished
total time:  845
accuracy for vid:  test02.mp4  is :  0.4849137931034483
Eval video:  test03.mp4  ...
finished
total time:  221
accuracy for vid:  test03.mp4  is :  0.852882703777336
Eval video:  test04.mp4  ...
finished
total time:  212
accuracy for vid:  test04.mp4  is :  0.7577639751552795
Eval video:  test05.mp4  ...
finished
total time:  208
accuracy for vid:  test05.mp4  is :  0.7398720682302772
Eval video:  test06.mp4  ...
finished
total time:  252
accuracy for vid:  test06.mp4  is :  0.679646017699115
Eval video:  test07.mp4  ...
finished
total time:  470
accuracy for vid:  test07.mp4  is :  0.5480769230769231
Eval video:  test08.mp4  ...
finished
total time:  191
accuracy for vid:  test08.mp4  is :  0.744920993227991
Eval video:  test09.mp4  ...
finished
total time:  282
accuracy for vid:  test09.mp4  is :  0.751937984496124
Eval 

In [373]:
# Sweep run: threshold=0.15, num_fill=32. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.15, num_fill=32)


Eval video:  test01.mp4  ...
finished
total time:  321
accuracy for vid:  test01.mp4  is :  0.7224435590969456
Eval video:  test02.mp4  ...
finished
total time:  803
accuracy for vid:  test02.mp4  is :  0.5102370689655172
Eval video:  test03.mp4  ...
finished
total time:  211
accuracy for vid:  test03.mp4  is :  0.852882703777336
Eval video:  test04.mp4  ...
finished
total time:  198
accuracy for vid:  test04.mp4  is :  0.7929606625258799
Eval video:  test05.mp4  ...
finished
total time:  204
accuracy for vid:  test05.mp4  is :  0.767590618336887
Eval video:  test06.mp4  ...
finished
total time:  236
accuracy for vid:  test06.mp4  is :  0.6991150442477876
Eval video:  test07.mp4  ...
finished
total time:  429
accuracy for vid:  test07.mp4  is :  0.5711538461538461
Eval video:  test08.mp4  ...
finished
total time:  191
accuracy for vid:  test08.mp4  is :  0.781038374717833
Eval video:  test09.mp4  ...
finished
total time:  281
accuracy for vid:  test09.mp4  is :  0.7441860465116279
Eval

In [None]:
# Re-run the model setup with the epoch-50 TriDet checkpoint (same reverse config).
ckpt_file = '/data/disk/LUO/test_only/TriDet/ckpt/cataract_slowfast_reverse_reverse/epoch_050.pth.tar'

# --- SlowFast feature extractor (epoch-130 checkpoint) ---
extract_model = build_model(cfg).to(device)
extract_checkpoint = torch.load(
    "/data/disk/LUO/slowfast/feature_extract/checkpoints/checkpoint_epoch_00130.pyth",
    map_location=device,
)
print("Loading from Feature Extraction Model ...")
extract_model.load_state_dict(extract_checkpoint['model_state'])
del extract_checkpoint

# --- TriDet detector (DataParallel over tridet_cfg['devices']) ---
tridet_net = make_meta_arch(tridet_cfg['model_name'], **tridet_cfg['model'])
model = nn.DataParallel(tridet_net, device_ids=tridet_cfg['devices'])
print("=> loading checkpoint '{}'".format(ckpt_file))
# Move every stored tensor onto the first TriDet device while unpickling.
checkpoint = torch.load(
    ckpt_file,
    map_location=lambda storage, loc: storage.cuda(tridet_cfg['devices'][0]),
)
print("loading finish")
# Use the EMA weights instead of the raw training weights.
print("Loading from EMA model ...")
model.load_state_dict(checkpoint['state_dict_ema'])
del checkpoint
print("loading finish")

# --- evaluation data ---
gt = pd.read_csv("/data/disk/LUO/test_only/r_tridet/TriDet/df_gt.csv")
dir_name = "/data/disk/LUO/cataract_test_video"



In [374]:
# Sweep run: threshold=0.135, num_fill=32. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.135, num_fill=32)

Eval video:  test01.mp4  ...
finished
total time:  320
accuracy for vid:  test01.mp4  is :  0.7250996015936255
Eval video:  test02.mp4  ...
finished
total time:  808
accuracy for vid:  test02.mp4  is :  0.5043103448275862
Eval video:  test03.mp4  ...
finished
total time:  203
accuracy for vid:  test03.mp4  is :  0.8290258449304175
Eval video:  test04.mp4  ...
finished
total time:  206
accuracy for vid:  test04.mp4  is :  0.7743271221532091
Eval video:  test05.mp4  ...
finished
total time:  197
accuracy for vid:  test05.mp4  is :  0.7654584221748401
Eval video:  test06.mp4  ...
finished
total time:  234
accuracy for vid:  test06.mp4  is :  0.6761061946902654
Eval video:  test07.mp4  ...
finished
total time:  438
accuracy for vid:  test07.mp4  is :  0.5634615384615385
Eval video:  test08.mp4  ...
finished
total time:  176
accuracy for vid:  test08.mp4  is :  0.7720090293453724
Eval video:  test09.mp4  ...
finished
total time:  269
accuracy for vid:  test09.mp4  is :  0.7286821705426356
E

In [375]:
# Sweep run: threshold=0.15, num_fill=28. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.15, num_fill=28)

Eval video:  test01.mp4  ...
finished
total time:  299
accuracy for vid:  test01.mp4  is :  0.7184594953519257
Eval video:  test02.mp4  ...
finished
total time:  739
accuracy for vid:  test02.mp4  is :  0.5129310344827587
Eval video:  test03.mp4  ...
finished
total time:  201
accuracy for vid:  test03.mp4  is :  0.8449304174950298
Eval video:  test04.mp4  ...
finished
total time:  188
accuracy for vid:  test04.mp4  is :  0.7784679089026915
Eval video:  test05.mp4  ...
finished
total time:  184
accuracy for vid:  test05.mp4  is :  0.7547974413646056
Eval video:  test06.mp4  ...
finished
total time:  231
accuracy for vid:  test06.mp4  is :  0.6831858407079646
Eval video:  test07.mp4  ...
finished
total time:  435
accuracy for vid:  test07.mp4  is :  0.5663461538461538
Eval video:  test08.mp4  ...
finished
total time:  185
accuracy for vid:  test08.mp4  is :  0.7539503386004515
Eval video:  test09.mp4  ...
finished
total time:  268
accuracy for vid:  test09.mp4  is :  0.7286821705426356
E

In [376]:
# Sweep run: threshold=0.15, num_fill=36. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.15, num_fill=36)

Eval video:  test01.mp4  ...
finished
total time:  314
accuracy for vid:  test01.mp4  is :  0.7237715803452855
Eval video:  test02.mp4  ...
finished
total time:  789
accuracy for vid:  test02.mp4  is :  0.5021551724137931
Eval video:  test03.mp4  ...
finished
total time:  205
accuracy for vid:  test03.mp4  is :  0.8608349900596421
Eval video:  test04.mp4  ...
finished
total time:  203
accuracy for vid:  test04.mp4  is :  0.7950310559006211
Eval video:  test05.mp4  ...
finished
total time:  187
accuracy for vid:  test05.mp4  is :  0.7611940298507462
Eval video:  test06.mp4  ...
finished
total time:  239
accuracy for vid:  test06.mp4  is :  0.7150442477876107
Eval video:  test07.mp4  ...
finished
total time:  422
accuracy for vid:  test07.mp4  is :  0.5682692307692307
Eval video:  test08.mp4  ...
finished
total time:  179
accuracy for vid:  test08.mp4  is :  0.8103837471783296
Eval video:  test09.mp4  ...
finished
total time:  269
accuracy for vid:  test09.mp4  is :  0.737984496124031
Ev

In [377]:
# Sweep run: threshold=0.15, num_fill=40. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.15, num_fill=40)

Eval video:  test01.mp4  ...
finished
total time:  303
accuracy for vid:  test01.mp4  is :  0.7197875166002656
Eval video:  test02.mp4  ...
finished
total time:  765
accuracy for vid:  test02.mp4  is :  0.5026939655172413
Eval video:  test03.mp4  ...
finished
total time:  207
accuracy for vid:  test03.mp4  is :  0.8628230616302187
Eval video:  test04.mp4  ...
finished
total time:  195
accuracy for vid:  test04.mp4  is :  0.7867494824016563
Eval video:  test05.mp4  ...
finished
total time:  188
accuracy for vid:  test05.mp4  is :  0.767590618336887
Eval video:  test06.mp4  ...
finished
total time:  228
accuracy for vid:  test06.mp4  is :  0.7380530973451327
Eval video:  test07.mp4  ...
finished
total time:  402
accuracy for vid:  test07.mp4  is :  0.5769230769230769
Eval video:  test08.mp4  ...
finished
total time:  174
accuracy for vid:  test08.mp4  is :  0.8261851015801355
Eval video:  test09.mp4  ...
finished
total time:  261
accuracy for vid:  test09.mp4  is :  0.7441860465116279
Ev

In [378]:
# Sweep run: threshold=0.165, num_fill=32. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.165, num_fill=32)

Eval video:  test01.mp4  ...
finished
total time:  277
accuracy for vid:  test01.mp4  is :  0.6759628154050464
Eval video:  test02.mp4  ...
finished
total time:  675
accuracy for vid:  test02.mp4  is :  0.5
Eval video:  test03.mp4  ...
finished
total time:  185
accuracy for vid:  test03.mp4  is :  0.8588469184890656
Eval video:  test04.mp4  ...
finished
total time:  174
accuracy for vid:  test04.mp4  is :  0.7991718426501035
Eval video:  test05.mp4  ...
finished
total time:  174
accuracy for vid:  test05.mp4  is :  0.7718550106609808
Eval video:  test06.mp4  ...
finished
total time:  206
accuracy for vid:  test06.mp4  is :  0.7044247787610619
Eval video:  test07.mp4  ...
finished
total time:  368
accuracy for vid:  test07.mp4  is :  0.573076923076923
Eval video:  test08.mp4  ...
finished
total time:  166
accuracy for vid:  test08.mp4  is :  0.7900677200902935
Eval video:  test09.mp4  ...
finished
total time:  235
accuracy for vid:  test09.mp4  is :  0.7488372093023256
Eval video:  test

In [379]:
# Sweep run: threshold=0.165, num_fill=40. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.165, num_fill=40)

Eval video:  test01.mp4  ...
finished
total time:  264
accuracy for vid:  test01.mp4  is :  0.6613545816733067
Eval video:  test02.mp4  ...
finished
total time:  683
accuracy for vid:  test02.mp4  is :  0.4978448275862069
Eval video:  test03.mp4  ...
finished
total time:  179
accuracy for vid:  test03.mp4  is :  0.8667992047713717
Eval video:  test04.mp4  ...
finished
total time:  174
accuracy for vid:  test04.mp4  is :  0.7991718426501035
Eval video:  test05.mp4  ...
finished
total time:  165
accuracy for vid:  test05.mp4  is :  0.7654584221748401
Eval video:  test06.mp4  ...
finished
total time:  208
accuracy for vid:  test06.mp4  is :  0.7451327433628319
Eval video:  test07.mp4  ...
finished
total time:  380
accuracy for vid:  test07.mp4  is :  0.5798076923076924
Eval video:  test08.mp4  ...
finished
total time:  159
accuracy for vid:  test08.mp4  is :  0.8397291196388262
Eval video:  test09.mp4  ...
finished
total time:  233
accuracy for vid:  test09.mp4  is :  0.7503875968992249
E

In [380]:
# Sweep run: threshold=0.13, num_fill=42. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.13, num_fill=42)

Eval video:  test01.mp4  ...
finished
total time:  311
accuracy for vid:  test01.mp4  is :  0.7237715803452855
Eval video:  test02.mp4  ...
finished
total time:  780
accuracy for vid:  test02.mp4  is :  0.5005387931034483
Eval video:  test03.mp4  ...
finished
total time:  197
accuracy for vid:  test03.mp4  is :  0.8369781312127237
Eval video:  test04.mp4  ...
finished
total time:  190
accuracy for vid:  test04.mp4  is :  0.7929606625258799
Eval video:  test05.mp4  ...
finished
total time:  184
accuracy for vid:  test05.mp4  is :  0.7611940298507462
Eval video:  test06.mp4  ...
finished
total time:  223
accuracy for vid:  test06.mp4  is :  0.7097345132743362
Eval video:  test07.mp4  ...
finished
total time:  417
accuracy for vid:  test07.mp4  is :  0.5855769230769231
Eval video:  test08.mp4  ...
finished
total time:  179
accuracy for vid:  test08.mp4  is :  0.8171557562076749
Eval video:  test09.mp4  ...
finished
total time:  257
accuracy for vid:  test09.mp4  is :  0.7286821705426356
E

In [381]:

def main(vid, log_folder, threshold=0.1, num_fill=32):
    """Simulate online (frame-by-frame) phase recognition on one test video.

    Frames are read one by one, resized to 256x256 and pushed into a 32-frame
    sliding window. After the first 32 frames, every 30th window (once per
    second, assuming 30 fps) is turned into a SlowFast feature with
    ``test(extract_model, cfg, ...)``. All features collected so far are passed
    through ``mirror_feature`` (with ``num_fill``) and the TriDet ``model``.
    ``get_middle_label`` then picks the label for the current step from a
    window of the detections:

    * feature length >= 1024: ``[512 - num_fill//2, 512 + num_fill//2]``
    * otherwise:              ``[cur_second + num_fill - 2, cur_second + num_fill]``

    Args:
        vid: video file name inside ``dir_name``, e.g. ``"test01.mp4"``. The
            number after ``"test"``, minus one, selects the ``_load_json()``
            entry used to build the model inputs.
        log_folder: existing directory, relative to the working directory,
            that receives ``result_real_time_<vid>.csv``.
        threshold: score threshold forwarded to ``get_middle_label``.
        num_fill: value passed to ``mirror_feature``; it also positions and
            sizes the label window.

    Side effects:
        Prints the wall-clock time and the accuracy against the last
        ``len(predictions)`` rows of ``gt`` for this video. Writes a CSV of
        those rows with ``pred_labels`` / ``pred_scores`` columns added. If the
        video is too short to yield a prediction, nothing is written.

    Relies on notebook globals: dir_name, cfg, extract_model, model, gt and
    the helpers defined in earlier cells.
    """
    seq_length = 32            # frames per clip (the fast pathway gets all 32)
    sample_stride = 30         # one feature every 30 windows (~1 s at 30 fps)
    sample_size = (256, 256)   # cv2 dsize is (width, height)
    max_len = 1024             # from this feature length on, a fixed middle window is used

    print("Eval video: ", vid, " ...")
    vid_path = os.path.join(dir_name, vid)
    idx = int(os.path.basename(vid_path).split(".")[0].split("test")[-1]) - 1  # test01 -> 0
    dict_db, _ = _load_json()  # dict_db holds the segments/labels of every test video

    queue = deque(maxlen=seq_length)
    frame_wise_feature = None
    count = 0     # frame counter
    count_2 = 0   # windows seen after the queue filled up
    pred_list = []
    score_list = []

    cap = cv2.VideoCapture(vid_path)
    start = datetime.datetime.now()
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("finished")
                break
            frame = cv2.resize(frame, sample_size, interpolation=cv2.INTER_LINEAR)
            queue.append(frame)
            if count < seq_length:  # warm-up: only fill the queue
                count += 1
                continue

            if count_2 % sample_stride == 0:
                cur_second = count // sample_stride
                # Pre-processing is only needed for sampled windows, so it lives here
                # instead of running on every frame.
                frames = pre_process_frame(np.stack(queue, axis=0))
                # list[0]: slow pathway (3, 8, 256, 256); list[1]: fast pathway (3, 32, 256, 256)
                frame_list = pack_pathway_output(cfg, frames)
                frame_list[0] = frame_list[0].unsqueeze(0)  # add a batch dimension
                frame_list[1] = frame_list[1].unsqueeze(0)
                # one feature per 32-frame clip, (1, 2304) per the original note
                cur_feature = test(extract_model, cfg, frame_list)
                if frame_wise_feature is None:
                    frame_wise_feature = cur_feature
                else:
                    # stack features along dim 0 as the input sequence for the detector
                    frame_wise_feature = np.concatenate((frame_wise_feature, cur_feature), axis=0)

                inputs = getitem(data_list=dict_db, features=frame_wise_feature,
                                 idx=idx, num_frames=len(frame_wise_feature))
                results = valid_one_epoch([mirror_feature(inputs, num_fill)], model=model)
                df = to_df(results)
                if inputs["feats"].size(1) >= max_len:
                    window = [max_len // 2 - num_fill // 2, max_len // 2 + num_fill // 2]
                else:
                    window = [cur_second + num_fill - 2, cur_second + num_fill]
                pred_label, score = get_middle_label(df, window, threshold)
                pred_list.append(pred_label)
                score_list.append(score)
            count += 1
            count_2 += 1
    finally:
        cap.release()
    end = datetime.datetime.now()
    print("total time: ", (end - start).seconds)

    if not pred_list:
        # A [-0:] slice would select *all* gt rows and accuracy_score would fail.
        print("no prediction for vid: ", vid, " (video too short); nothing saved")
        return

    # Last len(pred_list) ground-truth rows of this video. The copy lets columns be
    # added without touching `gt`.
    result_df = gt[gt["video_id"] == vid].iloc[-len(pred_list):].copy()
    acc = accuracy_score(result_df["gt_labels"], pred_list)
    print("accuracy for vid: ", vid, " is : ", acc)

    result_df["pred_labels"] = pred_list
    result_df["pred_scores"] = score_list
    result_df.to_csv(os.path.join(".", log_folder, "result_real_time_" + vid + ".csv"))

In [382]:
# Sweep run: threshold=0.13, num_fill=42. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
print(cur_time)
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.13, num_fill=42)

2024-02-01-12-03-06_result
Eval video:  test01.mp4  ...
finished
total time:  319
accuracy for vid:  test01.mp4  is :  0.7237715803452855
Eval video:  test02.mp4  ...
finished
total time:  763
accuracy for vid:  test02.mp4  is :  0.5005387931034483
Eval video:  test03.mp4  ...
finished
total time:  209
accuracy for vid:  test03.mp4  is :  0.8369781312127237
Eval video:  test04.mp4  ...
finished
total time:  204
accuracy for vid:  test04.mp4  is :  0.7929606625258799
Eval video:  test05.mp4  ...
finished
total time:  186
accuracy for vid:  test05.mp4  is :  0.7611940298507462
Eval video:  test06.mp4  ...
finished
total time:  235
accuracy for vid:  test06.mp4  is :  0.7097345132743362
Eval video:  test07.mp4  ...
finished
total time:  426
accuracy for vid:  test07.mp4  is :  0.5855769230769231
Eval video:  test08.mp4  ...
finished
total time:  183
accuracy for vid:  test08.mp4  is :  0.8171557562076749
Eval video:  test09.mp4  ...
finished
total time:  258
accuracy for vid:  test09.mp4 

In [383]:
# Sweep run: threshold=0.1, num_fill=40. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
print(cur_time)
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.1, num_fill=40)

2024-02-01-13-50-27_result
Eval video:  test01.mp4  ...
finished
total time:  393
accuracy for vid:  test01.mp4  is :  0.7184594953519257
Eval video:  test02.mp4  ...
finished
total time:  966
accuracy for vid:  test02.mp4  is :  0.47413793103448276
Eval video:  test03.mp4  ...
finished
total time:  267
accuracy for vid:  test03.mp4  is :  0.8170974155069582
Eval video:  test04.mp4  ...
finished
total time:  241
accuracy for vid:  test04.mp4  is :  0.7784679089026915
Eval video:  test05.mp4  ...
finished
total time:  228
accuracy for vid:  test05.mp4  is :  0.7313432835820896
Eval video:  test06.mp4  ...
finished
total time:  265
accuracy for vid:  test06.mp4  is :  0.6814159292035398
Eval video:  test07.mp4  ...
finished
total time:  518
accuracy for vid:  test07.mp4  is :  0.5855769230769231
Eval video:  test08.mp4  ...
finished
total time:  226
accuracy for vid:  test08.mp4  is :  0.7765237020316027
Eval video:  test09.mp4  ...
finished
total time:  321
accuracy for vid:  test09.mp4

In [384]:
# Sweep run: threshold=0.15, num_fill=40. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
print(cur_time)
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.15, num_fill=40)

2024-02-01-17-30-33_result
Eval video:  test01.mp4  ...
finished
total time:  1019
accuracy for vid:  test01.mp4  is :  0.7197875166002656
Eval video:  test02.mp4  ...
finished
total time:  2612
accuracy for vid:  test02.mp4  is :  0.5026939655172413
Eval video:  test03.mp4  ...
finished
total time:  651
accuracy for vid:  test03.mp4  is :  0.8628230616302187
Eval video:  test04.mp4  ...
finished
total time:  634
accuracy for vid:  test04.mp4  is :  0.7867494824016563
Eval video:  test05.mp4  ...
finished
total time:  549
accuracy for vid:  test05.mp4  is :  0.767590618336887
Eval video:  test06.mp4  ...
finished
total time:  714
accuracy for vid:  test06.mp4  is :  0.7380530973451327
Eval video:  test07.mp4  ...
finished
total time:  1418
accuracy for vid:  test07.mp4  is :  0.5769230769230769
Eval video:  test08.mp4  ...
finished
total time:  594
accuracy for vid:  test08.mp4  is :  0.8261851015801355
Eval video:  test09.mp4  ...
finished
total time:  973
accuracy for vid:  test09.mp

In [385]:

def main(vid, log_folder, threshold=0.1, num_fill=32):
    """Simulate online (frame-by-frame) phase recognition on one test video.

    Frames are read one by one, resized to 256x256 and pushed into a 32-frame
    sliding window. After the first 32 frames, every 30th window (once per
    second, assuming 30 fps) is turned into a SlowFast feature with
    ``test(extract_model, cfg, ...)``. All features collected so far are passed
    through ``mirror_feature`` (with ``num_fill``) and the TriDet ``model``.
    ``get_middle_label`` then picks the label for the current step from a
    window of the detections:

    * feature length >= 1024: ``[512 - num_fill//2, 512 + num_fill//2]``
    * otherwise:              ``[cur_second, cur_second + num_fill]``
      (a wider window than the ``[cur_second + num_fill - 2, ...]`` variant)

    Args:
        vid: video file name inside ``dir_name``, e.g. ``"test01.mp4"``. The
            number after ``"test"``, minus one, selects the ``_load_json()``
            entry used to build the model inputs.
        log_folder: existing directory, relative to the working directory,
            that receives ``result_real_time_<vid>.csv``.
        threshold: score threshold forwarded to ``get_middle_label``.
        num_fill: value passed to ``mirror_feature``; it also positions and
            sizes the label window.

    Side effects:
        Prints the wall-clock time and the accuracy against the last
        ``len(predictions)`` rows of ``gt`` for this video. Writes a CSV of
        those rows with ``pred_labels`` / ``pred_scores`` columns added. If the
        video is too short to yield a prediction, nothing is written.

    Relies on notebook globals: dir_name, cfg, extract_model, model, gt and
    the helpers defined in earlier cells.
    """
    seq_length = 32            # frames per clip (the fast pathway gets all 32)
    sample_stride = 30         # one feature every 30 windows (~1 s at 30 fps)
    sample_size = (256, 256)   # cv2 dsize is (width, height)
    max_len = 1024             # from this feature length on, a fixed middle window is used

    print("Eval video: ", vid, " ...")
    vid_path = os.path.join(dir_name, vid)
    idx = int(os.path.basename(vid_path).split(".")[0].split("test")[-1]) - 1  # test01 -> 0
    dict_db, _ = _load_json()  # dict_db holds the segments/labels of every test video

    queue = deque(maxlen=seq_length)
    frame_wise_feature = None
    count = 0     # frame counter
    count_2 = 0   # windows seen after the queue filled up
    pred_list = []
    score_list = []

    cap = cv2.VideoCapture(vid_path)
    start = datetime.datetime.now()
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                print("finished")
                break
            frame = cv2.resize(frame, sample_size, interpolation=cv2.INTER_LINEAR)
            queue.append(frame)
            if count < seq_length:  # warm-up: only fill the queue
                count += 1
                continue

            if count_2 % sample_stride == 0:
                cur_second = count // sample_stride
                # Pre-processing is only needed for sampled windows, so it lives here
                # instead of running on every frame.
                frames = pre_process_frame(np.stack(queue, axis=0))
                # list[0]: slow pathway (3, 8, 256, 256); list[1]: fast pathway (3, 32, 256, 256)
                frame_list = pack_pathway_output(cfg, frames)
                frame_list[0] = frame_list[0].unsqueeze(0)  # add a batch dimension
                frame_list[1] = frame_list[1].unsqueeze(0)
                # one feature per 32-frame clip, (1, 2304) per the original note
                cur_feature = test(extract_model, cfg, frame_list)
                if frame_wise_feature is None:
                    frame_wise_feature = cur_feature
                else:
                    # stack features along dim 0 as the input sequence for the detector
                    frame_wise_feature = np.concatenate((frame_wise_feature, cur_feature), axis=0)

                inputs = getitem(data_list=dict_db, features=frame_wise_feature,
                                 idx=idx, num_frames=len(frame_wise_feature))
                results = valid_one_epoch([mirror_feature(inputs, num_fill)], model=model)
                df = to_df(results)
                if inputs["feats"].size(1) >= max_len:
                    window = [max_len // 2 - num_fill // 2, max_len // 2 + num_fill // 2]
                else:
                    window = [cur_second, cur_second + num_fill]
                pred_label, score = get_middle_label(df, window, threshold)
                pred_list.append(pred_label)
                score_list.append(score)
            count += 1
            count_2 += 1
    finally:
        cap.release()
    end = datetime.datetime.now()
    print("total time: ", (end - start).seconds)

    if not pred_list:
        # A [-0:] slice would select *all* gt rows and accuracy_score would fail.
        print("no prediction for vid: ", vid, " (video too short); nothing saved")
        return

    # Last len(pred_list) ground-truth rows of this video. The copy lets columns be
    # added without touching `gt`.
    result_df = gt[gt["video_id"] == vid].iloc[-len(pred_list):].copy()
    acc = accuracy_score(result_df["gt_labels"], pred_list)
    print("accuracy for vid: ", vid, " is : ", acc)

    result_df["pred_labels"] = pred_list
    result_df["pred_scores"] = score_list
    result_df.to_csv(os.path.join(".", log_folder, "result_real_time_" + vid + ".csv"))

In [386]:
# Sweep run: threshold=0.13, num_fill=42. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
print(cur_time)
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.13, num_fill=42)

2024-02-01-23-13-44_result
Eval video:  test01.mp4  ...
finished
total time:  993
accuracy for vid:  test01.mp4  is :  0.6786188579017264
Eval video:  test02.mp4  ...
finished
total time:  2422
accuracy for vid:  test02.mp4  is :  0.5032327586206896
Eval video:  test03.mp4  ...
finished
total time:  652
accuracy for vid:  test03.mp4  is :  0.7137176938369781
Eval video:  test04.mp4  ...
finished
total time:  636
accuracy for vid:  test04.mp4  is :  0.6231884057971014
Eval video:  test05.mp4  ...
finished
total time:  625
accuracy for vid:  test05.mp4  is :  0.6204690831556503
Eval video:  test06.mp4  ...
finished
total time:  772
accuracy for vid:  test06.mp4  is :  0.6070796460176991
Eval video:  test07.mp4  ...
finished
total time:  1375
accuracy for vid:  test07.mp4  is :  0.5144230769230769
Eval video:  test08.mp4  ...
finished
total time:  558
accuracy for vid:  test08.mp4  is :  0.6433408577878104
Eval video:  test09.mp4  ...
finished
total time:  769
accuracy for vid:  test09.mp

In [387]:
# Sweep run: threshold=0.1, num_fill=40. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
print(cur_time)
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.1, num_fill=40)

2024-02-02-04-49-33_result
Eval video:  test01.mp4  ...
finished
total time:  951
accuracy for vid:  test01.mp4  is :  0.6772908366533864
Eval video:  test02.mp4  ...
finished
total time:  2369
accuracy for vid:  test02.mp4  is :  0.47629310344827586
Eval video:  test03.mp4  ...
finished
total time:  658
accuracy for vid:  test03.mp4  is :  0.709741550695825
Eval video:  test04.mp4  ...
finished
total time:  560
accuracy for vid:  test04.mp4  is :  0.6128364389233955
Eval video:  test05.mp4  ...
finished
total time:  573
accuracy for vid:  test05.mp4  is :  0.5970149253731343
Eval video:  test06.mp4  ...
finished
total time:  699
accuracy for vid:  test06.mp4  is :  0.6070796460176991
Eval video:  test07.mp4  ...
finished
total time:  1341
accuracy for vid:  test07.mp4  is :  0.5221153846153846
Eval video:  test08.mp4  ...
finished
total time:  517
accuracy for vid:  test08.mp4  is :  0.6388261851015802
Eval video:  test09.mp4  ...
finished
total time:  767
accuracy for vid:  test09.mp

In [388]:
# Sweep run: threshold=0.15, num_fill=40. Results go to a fresh timestamped folder.
cur_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "_result"
print(cur_time)
os.mkdir(cur_time)
# main() joins each name with dir_name, so list that same directory. Sort because
# os.listdir() order is arbitrary, which would make runs and logs non-reproducible.
for vid_name in sorted(os.listdir(dir_name)):
    main(vid_name, cur_time, threshold=0.15, num_fill=40)

2024-02-02-09-00-26_result
Eval video:  test01.mp4  ...
finished
total time:  420
accuracy for vid:  test01.mp4  is :  0.6772908366533864
Eval video:  test02.mp4  ...
finished
total time:  1059
accuracy for vid:  test02.mp4  is :  0.5053879310344828
Eval video:  test03.mp4  ...
finished
total time:  293
accuracy for vid:  test03.mp4  is :  0.7196819085487077
Eval video:  test04.mp4  ...
finished
total time:  290
accuracy for vid:  test04.mp4  is :  0.6252587991718427
Eval video:  test05.mp4  ...
finished
total time:  279
accuracy for vid:  test05.mp4  is :  0.6183368869936035
Eval video:  test06.mp4  ...
finished
total time:  330
accuracy for vid:  test06.mp4  is :  0.6212389380530974
Eval video:  test07.mp4  ...
finished
total time:  617
accuracy for vid:  test07.mp4  is :  0.5125
Eval video:  test08.mp4  ...
finished
total time:  246
accuracy for vid:  test08.mp4  is :  0.654627539503386
Eval video:  test09.mp4  ...
finished
total time:  351
accuracy for vid:  test09.mp4  is :  0.688

In [389]:
# Predictions/ground truth saved by an SVRC experiment, for comparison with the
# real-time TriDet results. Per the directory name: best checkpoint, acc 72.11.
svrcn_csv_path = "/data/disk/video_detection/SVRC/experiment/20231109-233357__SGD_lr0.01_factor1.0_step5_gamma0.5_wd1e-06_noseed/result-best-8-acc72.11-f58.73-recall60.65-prec62.33-jmicro56.39/tmp_preds_gt.csv"
svrcn = pd.read_csv(svrcn_csv_path)

In [392]:
# Load the cataract annotation file (data_1102.json).
annotation_json_path = "/data/disk/LUO/test_only/r_tridet/TriDet/data/cataract/data_1102.json"
with open(annotation_json_path, "r") as fh:
    json_data = json.load(fh)

In [395]:
# Write a human-readable (2-space indented) copy of the annotation file next to the original.
indented_json_path = "/data/disk/LUO/test_only/r_tridet/TriDet/data/cataract/data_1102_indent.json"
with open(indented_json_path, "w") as fh:
    json.dump(json_data, fh, indent=2)

In [None]:
# Sample entry copied from the annotation JSON, kept as a schema reference:
# video name -> {"annotation": [segment, ...]}, where each segment has start/end/duration,
# an integer phase "label", its "subset", and "time_till_now" (equal to "end" in this sample).
{"train24.mp4": {
    "annotation": [
      {
        "start": 11.433333333333334,
        "end": 41.233333333333334,
        "label": 1,
        "duration": 29.8,
        "subset": "training",
        "time_till_now": 41.233333333333334
      }]}}