In [2]:


# Example use: python inference.py --vid_name=hyw7Ue6oW8w.mp4
# NOTE(review): parse_args() in this file parses an empty argument list, so a
# flag given on the command line as shown above is not read as written.

print("Starting script!")

# IPython/Jupyter magic (not plain Python): change the working directory to the scripts folder.
%cd /home/egoodman/multitaskmodel/scripts



# Action label names. Assigned to SurgeryDataset.categories after the dataset import,
# and indexed by the argmax of the action logits inside inference().
DEFAULT_CATEGORIES = ['cutting', 'tying', 'suturing', 'background']
# Detection class names, indexed by the class id predicted for each box inside inference().
DETECTION_CLASSES = ['bovie', 'forceps', 'needledriver', 'hand']

import sys
# Make the project-local TSM and RetinaNet sources importable; the imports below
# (dataset, train, model, retinanet) rely on these entries being on sys.path first.
sys.path.append("/home/egoodman/multitaskmodel/MULTITASK_FILES/TSM_FILES/")
sys.path.append('/home/egoodman/multitaskmodel/MULTITASK_FILES/RETINANET_FILES/src/pytorch-retinanet/')

from dataset import * # star import from the TSM dataset module; provides SurgeryDataset used below
# Override the dataset's category list with the action labels used by this script.
SurgeryDataset.categories = DEFAULT_CATEGORIES
from train import get_train_val_data_loaders, run_epoch
from model import get_model_name, save_model, save_results, get_model
from barbar import Bar

import torch.nn as nn
from torch.utils.data import DataLoader

import cv2
import utils
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt 
import numpy as np
import json
import csv
from collections import defaultdict
import shutil
import timeit
import argparse




def parse_args():
    """Return the option namespace for a surgery-video inference run.

    The parser is fed an empty argument string on purpose, so ``sys.argv`` is
    never consulted and the returned namespace always carries the defaults
    declared here (``vid_name``, ``directory``, ``tool_model_loc``).
    """
    arg_parser = argparse.ArgumentParser(description='Surgery Hand and Keypoint Detection on Video')

    # (flag, default) pairs; every option is a plain string.
    string_options = (
        ('--vid_name', "slap.mp4"),
        ('--directory', "./output/"),
        ('--tool_model_loc', "/home/egoodman/multitaskmodel/logs/best_models/20210427_bigmultitask_basicactionhead_352h64d_newinference_99_incomplete.pt"),
    )
    for flag, default_value in string_options:
        arg_parser.add_argument(flag, type=str, default=default_value)

    # An empty string expands to an empty argv list: defaults only.
    return arg_parser.parse_args("")


def get_test_data_loaders(segments_df, batch_size, data_dir='data/', model='TSM', pre_crop_size=352, segment_length=5,
                                                                    aug_method='val'):
    """Wrap the given segments in a non-shuffling test ``DataLoader``.

    The segments are ordered by ``video_id`` and then ``start_seconds`` before
    being handed to ``SurgeryDataset`` in ``'test'`` mode with balancing off.
    ``segment_length`` is accepted for signature compatibility but is not used
    in this function.

    Returns:
        A ``DataLoader`` over the dataset (no shuffling, no worker processes,
        no pinned memory).
    """
    ordered_segments = segments_df.sort_values(by=['video_id', 'start_seconds'])

    dataset_options = dict(
        data_dir=data_dir,
        mode='test',
        model=model,
        balance=False,
        pre_crop_size=pre_crop_size,
        aug_method=aug_method,
    )
    dataset = SurgeryDataset(ordered_segments, **dataset_options)

    # Order matters downstream, so batches are produced sequentially in-process.
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False,
                      num_workers=0, pin_memory=False)



def get_video_path(video_id, data_dir='data/'):
    """Return the path of the ``.mp4`` file for ``video_id`` inside ``data_dir``.

    The directory and file name are joined as separate path components, so
    ``data_dir`` works with or without a trailing separator (``'data'`` and
    ``'data/'`` both give ``data/<video_id>.mp4``).

    Args:
        video_id: Video name without the ``.mp4`` extension.
        data_dir: Directory that holds the video files.

    Returns:
        The path string ``<data_dir>/<video_id>.mp4``.
    """
    return os.path.join(data_dir, video_id + ".mp4")



def get_video_duration(filename):
    """Return ``(frame_count, fps)`` as reported by OpenCV for ``filename``.

    Despite the name, this does not return a duration; callers divide
    ``frame_count`` by ``fps`` themselves. The capture handle is released
    before returning so repeated calls do not keep files open.

    Args:
        filename: Path of the video file to inspect.

    Returns:
        Tuple ``(frame_count, fps)`` of the raw values from
        ``cv2.CAP_PROP_FRAME_COUNT`` and ``cv2.CAP_PROP_FPS``.
    """
    video = cv2.VideoCapture(filename)
    try:
        frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
        fps = video.get(cv2.CAP_PROP_FPS)
    finally:
        # Only metadata is needed; always free the underlying file/decoder.
        video.release()
    return frame_count, fps



def anns_from_video_ids(video_ids, data_dir, segment_length, offset=0):
    """Build placeholder segment annotations that tile each video.

    Every video found on disk is cut into consecutive, non-overlapping windows
    of ``segment_length`` seconds (a trailing partial window is dropped), and
    each window is labelled ``'background'``.

    Args:
        video_ids: Iterable of video names without the ``.mp4`` extension.
        data_dir: Directory containing the video files.
        segment_length: Window length in seconds.
        offset: Seconds added to every window start.

    Returns:
        A ``pandas.DataFrame`` with columns ``start_seconds``, ``video_id``,
        ``end_seconds``, ``duration``, ``label`` and ``category``. The columns
        are present even when no video produced any rows, so callers can still
        sort or select by them.
    """
    columns = ['start_seconds', 'video_id', 'end_seconds', 'duration', 'label', 'category']
    label = 'background'

    print("Studying", video_ids)
    rows = []
    for video_id in video_ids:
        video_path = get_video_path(video_id, data_dir)
        print("Studying video at path", video_path)
        if not os.path.exists(video_path):
            print("Video not downloaded: %s" % video_id)
            continue
        frame_count, fps = get_video_duration(video_path)
        if not fps > 0:
            # OpenCV reports 0 for properties it cannot read; skip the video
            # instead of dividing by zero below.
            print("Could not read frame rate for video: %s" % video_id)
            continue
        num_anns = int(frame_count / fps / segment_length)
        for i in range(num_anns):
            start_seconds = offset + i * segment_length
            rows.append({'start_seconds': start_seconds,
                         'video_id': video_id,
                         'end_seconds': start_seconds + segment_length,
                         'duration': segment_length,
                         'label': label,
                         'category': label})
    return pd.DataFrame(rows, columns=columns)










from retinanet import model_3_heads
from PIL import Image

from pycocotools.coco import COCO

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision

# Log the library versions in use for this run.
print("We are using torch version", torch.__version__)
print("We are using torchvision version", torchvision.__version__)


# Box colour per detection class id; inference() looks colours up with the same
# id it uses to index DETECTION_CLASSES (0 bovie, 1 forceps, 2 needledriver, 3 hand).
# Read as RGB the tuples are red, green, blue and gold.
# NOTE(review): OpenCV drawing calls normally treat colour tuples as BGR - confirm the
# channel order of the frames these colours are drawn on.
annot_to_color = {0 : (255, 0, 0), 1 : (0, 255, 0), 2 : (0, 0, 255), 3 : (255, 215, 0) }


def inference(test_data_loader, vid_name, tool_model_loc, directory):
    """Run the multitask model over every batch and write annotated outputs.

    For each batch the model returns per-frame detection scores, classes and
    boxes plus action logits. Every frame is de-normalised and written to a
    "clean" video; boxes scoring at least 0.5 and the predicted action label
    are drawn on the same frame, which is written to a second video; the
    per-frame results are collected and dumped as JSON.

    Files written into ``directory`` (``<stem>`` is ``vid_name`` minus its last
    four characters, i.e. a ``.mp4``-style extension is assumed):
        ``<stem>_detections.mp4``, ``<stem>_nodetections.mp4``,
        ``<stem>_detections.json``.

    Args:
        test_data_loader: Iterable yielding ``(data, record_ids, labels)``
            batches; ``data`` is reshaped to ``(-1, 3, H, W)`` before use.
        vid_name: File name of the source video inside ``directory``; it is
            only opened to read its frame rate and frame count.
        tool_model_loc: Path of the pickled model loaded with ``torch.load``.
        directory: Directory holding the source video and receiving outputs.

    Returns:
        Tuple ``(duration_seconds, elapsed_seconds, original_dimensions)``:
        source-video duration (0.0 if the frame rate could not be read),
        wall-clock time spent in the inference loop, and ``record_ids[1]`` of
        the first batch.

    Raises:
        ValueError: If ``test_data_loader`` yields no batches.
    """
    # Per-channel statistics used to undo the input normalisation.
    norm_std = np.array([[[0.2650, 0.2877, 0.3311]]])
    norm_mean = np.array([[[0.3051, 0.3570, 0.4115]]])
    score_threshold = .5
    out_fps = 13  # frame rate defined by model
    actions_vis = True  # draw the action label on the annotated video

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # torch.load returns the complete pickled model, so no model is constructed
    # beforehand. NOTE: torch.load unpickles arbitrary objects - only point
    # tool_model_loc at trusted checkpoint files.
    tool_model = torch.load(tool_model_loc, map_location=device)
    tool_model.to(device)
    tool_model = torch.nn.DataParallel(tool_model)
    # Always switch to eval mode, regardless of which device is in use.
    tool_model.eval()

    # Peek at the first batch only to learn the frame size and the original dimensions.
    try:
        first_data, first_record_ids, _ = next(iter(test_data_loader))
    except StopIteration:
        raise ValueError("test_data_loader yielded no batches for video {}".format(vid_name)) from None
    print("Grabbing dimensions")
    original_dimensions = first_record_ids[1]
    print("Original dimensions are", first_record_ids[1])
    first_data = first_data.view((-1, 3) + first_data.size()[-2:])
    _, _, height, width = first_data.shape

    # The source video is only needed for its metadata; release it right away.
    source_video = cv2.VideoCapture(os.path.join(directory, vid_name))
    try:
        fps_video = source_video.get(cv2.CAP_PROP_FPS)
        frame_count = int(source_video.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        source_video.release()
    # OpenCV reports 0 fps for a file it cannot read; avoid dividing by zero.
    duration = frame_count / fps_video if fps_video > 0 else 0.0

    stem = vid_name[:-4]  # assumes a four-character extension such as ".mp4"
    out_video = os.path.join(directory, stem + "_detections.mp4")
    clean_out_video = os.path.join(directory, stem + "_nodetections.mp4")
    json_path = os.path.join(directory, stem + "_detections.json")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    frame_size = (int(width), int(height))
    # Annotated video (boxes + action label) and clean video (frames only, for tracking).
    video_tracked = cv2.VideoWriter(out_video, fourcc, out_fps, frame_size)
    video_clean = cv2.VideoWriter(clean_out_video, fourcc, out_fps, frame_size)
    output_json = {}

    print('''

    Video has following parameters
    Frame Height : {}
    Frame Width : {}
    FPS : {}
    Number of Frames : {}
    Duration (seconds) : {}

    '''.format(height, width, fps_video, frame_count, duration))

    start_time = timeit.default_timer()
    print("Starting inference on video {} at time {}".format(vid_name, start_time))

    # Index of the next frame in the output videos; used as the JSON key so keys
    # always line up with written frames, whatever the number of frames per batch.
    frames_written = 0
    try:
        with torch.no_grad():
            for iter_num, (data_action, record_ids, _) in enumerate(test_data_loader):
                print("Original dimensions are", record_ids[1])

                # Reshape data and forward pass.
                data_action = data_action.view((-1, 3) + data_action.size()[-2:])
                print("Studying batch", iter_num, data_action.shape)
                batch_nms_scores, batch_nms_class, batch_transformed_anchors, action_logits = tool_model(data_action)

                # One action per batch: the label is the argmax over all action logits.
                # The displayed score is the max logit value (whether it is a
                # probability depends on the model head, which is not visible here).
                cur_action = DEFAULT_CATEGORIES[int(torch.argmax(action_logits))]
                cur_action_score = str(float(torch.max(action_logits)))[:5]

                # Convert the batch to NumPy once instead of once per frame.
                batch_frames = data_action.float().cpu().numpy()

                # Go frame by frame through the output and add to the videos.
                for frame_no in range(len(batch_nms_scores)):
                    nms_scores = batch_nms_scores[frame_no]
                    transformed_anchors = batch_transformed_anchors[frame_no]
                    nms_class = batch_nms_class[frame_no]

                    # CHW -> HWC, undo normalisation, clip so the uint8 cast cannot wrap.
                    frame = np.transpose(batch_frames[frame_no], (1, 2, 0))
                    frame = np.clip(255 * (frame * norm_std + norm_mean), 0, 255)
                    # OpenCV drawing needs a C-contiguous uint8 image.
                    frame = np.ascontiguousarray(frame.astype(np.uint8))

                    # Written before anything is drawn on the frame.
                    video_clean.write(frame)

                    frame_detections = []
                    keep = np.where(nms_scores.detach().cpu().numpy() >= score_threshold)[0]
                    for idx in keep:
                        idx = int(idx)
                        # The four values are used as two opposite box corners.
                        x1, y1, x2, y2 = transformed_anchors[idx].detach()[:4].tolist()
                        class_id = int(float(nms_class[idx].detach()))

                        print("Drawing rectangle", x1, y1, x2, y2)

                        cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)),
                                      color=annot_to_color[class_id], thickness=3)
                        frame_detections.append(
                            {DETECTION_CLASSES[class_id]: [x1, y1, x2, y2, float(nms_scores[idx])]})

                    if actions_vis:
                        cv2.putText(frame, cur_action + " " + cur_action_score, (50, 50),
                                    fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(255, 255, 255))

                    # Add all these annotations to the video and the json.
                    video_tracked.write(frame)
                    output_json[frames_written] = {"actions": (cur_action, cur_action_score),
                                                   "detections": frame_detections}
                    frames_written += 1

        # Calculate how fast our inference was.
        elapsed = timeit.default_timer() - start_time
        print('\n ...inference complete after {} !'.format(elapsed))
    finally:
        # Release the writers even if the loop fails, so the files are finalised.
        video_tracked.release()
        video_clean.release()

    with open(json_path, "w") as outfile:
        json.dump(output_json, outfile)
    print("Output video is ", out_video)

    return duration, elapsed, original_dimensions

    

def main():
    """Run the full pipeline for one video and record its timing summary.

    Steps: read the options, build background segments for the video, wrap
    them in a test data loader, run inference, then write
    ``<video name>_times.txt`` in the current working directory. The file
    holds the video name, video duration, inference time and the three
    original-dimension values, each followed by a single space.
    """
    args = parse_args()
    video_id = args.vid_name[:-4]  # file name without its four-character extension

    segments_df = anns_from_video_ids([video_id], args.directory, segment_length=5)
    test_data_loader = get_test_data_loaders(segments_df, batch_size=1, data_dir = args.directory)
    inference_outputs = inference(test_data_loader, args.vid_name, args.tool_model_loc, args.directory)

    print("Returning inference outputs", inference_outputs)

    video_seconds, inference_seconds, dimensions = inference_outputs[:3]
    print("Video was {} seconds, and inference was performed in {} seconds".format(video_seconds, inference_seconds))

    output = [video_id, video_seconds, inference_seconds]
    output.extend(float(dimensions[axis]) for axis in range(3))
    print("Printing output", output)

    summary_line = ''.join('%s ' % listitem for listitem in output)
    with open(video_id + '_times.txt', 'w') as filehandle:
        filehandle.write(summary_line)

# Run the pipeline only when this file is executed directly (script or notebook cell),
# not when it is imported as a module.
if __name__ == '__main__':
    main()


Starting script!
/home/egoodman/multitaskmodel/scripts
We are using torch version 1.8.0
We are using torchvision version 0.9.0
Studying ['slap']
Studying video at path ./output/slap.mp4
Creating action head
Grabbing dimensions
Original dimensions are [tensor([0]), tensor([0]), tensor([0])]


    Video has following parameters
    Frame Height : 384
    Frame Width : 384
    FPS : 30.0
    Number of Frames : 539
    Duration (seconds) : 17.966666666666665

    
Starting inference on video slap.mp4 at time 58497.554966035
Original dimensions are [tensor([0]), tensor([0]), tensor([0])]
Studying batch 0 torch.Size([64, 3, 384, 384])
Drawing rectangle 78.85747528076172 71.44671630859375 137.85812377929688 127.92196655273438
Drawing rectangle 70.24629211425781 85.19050598144531 131.17945861816406 138.34402465820312
Original dimensions are [tensor([0]), tensor([0]), tensor([0])]
Studying batch 1 torch.Size([64, 3, 384, 384])
Original dimensions are [tensor([0]), tensor([0]), tensor([0])]
Stud