In [None]:
import torch
import torch.nn as nn
import cv2
import os
import argparse
import math
import sys  
# Put the project root, the `network` package directory and its `data`
# sub-directory on the import path. Each entry is inserted at position 0, so
# afterwards the search order begins: network/data, network, project root.
for _extra_import_dir in (
    os.path.dirname(os.getcwd()),
    os.path.join(os.path.dirname(os.getcwd()), "network"),
    os.path.join(os.path.dirname(os.getcwd()), "network", "data"),
):
    sys.path.insert(0, _extra_import_dir)
del _extra_import_dir


from tqdm.notebook import tqdm
import network.fpn as fpn
import network.nms as nms
from network import footandball as footandball
from data import augmentation as augmentations

In [None]:
# The project root is taken to be the parent of the current working directory,
# so the notebook must be started from a direct sub-directory of the project.
# Videos, weights and the generated output all live in <root>/data.
ROOT_DIR = os.path.split(os.getcwd())[0]
DATA_FOLDER = os.path.join(ROOT_DIR, "data")

In [None]:
# Class ids compared against the model's detections['labels'].
# NOTE(review): presumably these match the label map used at training time — confirm.
BALL_LABEL, PLAYER_LABEL = 1, 2

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Trained detector weights loaded into the model further below.
# NOTE(review): the file name presumably encodes the training run's date/time — unverified.
WEIGHT_FILE = os.path.join(DATA_FOLDER, "model_20201019_1416_final.pth")

In [None]:
# Two camera views of the same scene, both stored in DATA_FOLDER.
# NOTE(review): judging by its name, the "vikture" clip was trimmed/shifted to
# compensate for starting 15 s late so that frame N matches in both — unverified.
video_waltter_path = os.path.join(DATA_FOLDER, "example_waltter_synchronized.mov")
video_vikture_path = os.path.join(DATA_FOLDER, "example_vikture_late_15s_synchronized.mov")

In [None]:
# Left view = "vikture" clip, right view = "waltter" clip.
video_left_capture = cv2.VideoCapture(video_vikture_path)
video_right_capture = cv2.VideoCapture(video_waltter_path)

# cv2.VideoCapture does not raise for a missing or unreadable file; it returns a
# closed capture whose reads all fail. Fail fast here instead of producing an
# empty output video after hundreds of "Error reading frame" messages.
if not video_left_capture.isOpened():
    raise IOError(f"Could not open left video: {video_vikture_path}")
if not video_right_capture.isOpened():
    raise IOError(f"Could not open right video: {video_waltter_path}")

In [None]:
# Upper bound on the number of frames processed (keeps the demo run short).
MAX_FRAMES = 600

left_n_frames = int(video_left_capture.get(cv2.CAP_PROP_FRAME_COUNT))
right_n_frames = int(video_right_capture.get(cv2.CAP_PROP_FRAME_COUNT))

print(left_n_frames)
print(right_n_frames)

# Never schedule more frames than both videos contain; otherwise the tail
# windows only produce "Error reading frame" messages. Backends may report a
# non-positive count when the container does not expose one; such counts are
# ignored rather than shrinking the run to nothing.
known_frame_counts = [n for n in (left_n_frames, right_n_frames) if n > 0]
total_frames = min([MAX_FRAMES] + known_frame_counts)

print(total_frames)

In [None]:
# Output video settings. OpenCV's VideoWriter expects every written frame to
# have exactly this (width, height).
final_fps = 60.0
final_width, final_height = 1920, 1080
fourcc = cv2.VideoWriter_fourcc(*'MJPG')  # Motion-JPEG codec for the .avi output

In [None]:
video_path = os.path.join(DATA_FOLDER, "example_human_detection_video_4.avi")
video_output = cv2.VideoWriter(video_path, fourcc, final_fps, (final_width, final_height))

# Like VideoCapture, VideoWriter does not raise when it cannot open the target
# (bad directory, unsupported codec/container); every write would then be
# silently discarded. Fail fast instead.
if not video_output.isOpened():
    raise IOError(f"Could not open video writer for: {video_path}")

In [None]:
def _conv_head(in_channels, hidden_channels, out_channels):
    """Two 3x3 convolutions with a ReLU in between; padding=1 keeps the spatial size."""
    return nn.Sequential(nn.Conv2d(in_channels, out_channels=hidden_channels, kernel_size=3, padding=1),
                         nn.ReLU(inplace=True),
                         nn.Conv2d(hidden_channels, out_channels=out_channels, kernel_size=3, padding=1))


def build_model(phase='detect', max_player_detections=100, max_ball_detections=100,
                player_threshold=0.7, ball_threshold=0.7):
    """Construct the FootAndBall detector (FPN backbone + ball/player heads).

    All arguments are forwarded unchanged to ``footandball.FootAndBall``; their
    defaults are the values this notebook has always used, so ``build_model()``
    behaves exactly as before.

    Args:
        phase: mode string passed to FootAndBall ('detect' here).
        max_player_detections: forwarded as ``max_player_detections``.
        max_ball_detections: forwarded as ``max_ball_detections``.
        player_threshold: forwarded as ``player_threshold``.
        ball_threshold: forwarded as ``ball_threshold``.

    Returns:
        A freshly constructed ``footandball.FootAndBall`` module; trained weights
        must be loaded separately.
    """
    layers, out_channels = fpn.make_modules(fpn.cfg['X'], batch_norm=True)
    lateral_channels = 32
    i_channels = 32

    base_net = fpn.FPN(layers, out_channels=out_channels, lateral_channels=lateral_channels, return_layers=[1, 3])
    # Heads are created in the original order (ball classifier, player classifier,
    # player regressor) so parameter initialisation consumes the RNG identically.
    # NOTE(review): 2 output channels presumably = per-location class scores and
    # 4 = box regression values — confirm against footandball.
    ball_classifier = _conv_head(lateral_channels, i_channels, 2)
    player_classifier = _conv_head(lateral_channels, i_channels, 2)
    player_regressor = _conv_head(lateral_channels, i_channels, 4)

    detector = footandball.FootAndBall(phase, base_net, player_regressor=player_regressor,
                                       player_classifier=player_classifier,
                                       ball_classifier=ball_classifier, ball_threshold=ball_threshold,
                                       player_threshold=player_threshold,
                                       max_ball_detections=max_ball_detections,
                                       max_player_detections=max_player_detections)
    return detector

In [None]:
model = build_model()
model = model.to(device)
# map_location remaps stored tensors onto `device`, so weights saved from a GPU
# session also load when `device` fell back to the CPU (without it torch.load
# tries to restore tensors onto the device they were saved from).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# weight files (newer PyTorch versions offer weights_only=True for this).
state_dict = torch.load(WEIGHT_FILE, map_location=device)

model.load_state_dict(state_dict)
# Set model to evaluation mode
model.eval()

In [None]:
def draw_bboxes(image, detections):
    """Draw detector output onto ``image`` and return that same image.

    OpenCV drawing functions write into the array, so ``image`` is modified in place.

    Args:
        image: frame to annotate. Colours follow OpenCV's (B, G, R) convention,
            so players are drawn blue and the ball red.
        detections: mapping with parallel sequences under ``'boxes'``
            (each unpacked as ``x1, y1, x2, y2``), ``'labels'`` and ``'scores'``.

    Players get a rectangle and the ball a fixed-radius circle centred on its
    box; each is captioned with its score to two decimals. Detections with any
    other label are skipped.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    triples = zip(detections['boxes'], detections['labels'], detections['scores'])

    for box, label, score in triples:
        if label == PLAYER_LABEL:
            x1, y1, x2, y2 = box
            blue = (255, 0, 0)
            top_left = (int(x1), int(y1))
            bottom_right = (int(x2), int(y2))
            # Caption sits 10 px above the box, clamped to the top edge.
            caption_origin = (int(x1), max(0, int(y1) - 10))
            cv2.rectangle(image, top_left, bottom_right, blue, 2)
            cv2.putText(image, '{:0.2f}'.format(score), caption_origin, font, 1, blue, 2)
            continue

        if label == BALL_LABEL:
            x1, y1, x2, y2 = box
            red = (0, 0, 255)
            radius = 25
            center_x = int((x1 + x2) / 2)
            center_y = int((y1 + y2) / 2)
            cv2.circle(image, (center_x, center_y), radius, red, 2)
            # Caption above-left of the circle, clamped to the image origin.
            caption_origin = (max(0, int(center_x - radius)), max(0, center_y - radius - 10))
            cv2.putText(image, '{:0.2f}'.format(score), caption_origin, font, 1, red, 2)

    return image

In [None]:
def equalize_histogram(rgb_image):
    """Histogram-equalize each of the three channels of ``rgb_image`` independently.

    Despite the parameter name, the channel order does not matter: every channel
    is processed identically and recombined in its original position. The
    three-way unpack means images without exactly three channels raise.

    Returns:
        A new image built by ``cv2.merge`` from the equalized channels.
    """
    first, second, third = cv2.split(rgb_image)
    equalized_channels = [cv2.equalizeHist(channel) for channel in (first, second, third)]
    return cv2.merge(equalized_channels)

In [None]:
def preprocess_image(image):
    """Return ``image`` after per-channel histogram equalization."""
    return equalize_histogram(image)

def preprocess_images(images):
    """Apply ``preprocess_image`` to each element of ``images``; returns a new list in input order."""
    return [preprocess_image(single_image) for single_image in images]

In [None]:
def run_detection(images):
    """Run the detector on every frame in ``images`` and draw its detections.

    Previously inference, counting and drawing sat after the frame loop, so only
    the last frame of each batch was detected and annotated (and an empty batch
    raised). Every frame is now processed.

    Args:
        images: iterable of frames as read from ``cv2.VideoCapture``. Each frame
            is annotated in place by ``draw_bboxes``.

    Returns:
        ``(detection_list, annotated_frames)``, one entry per input frame:
        ``detection_list[k]`` is ``(n_humans, detections)`` where ``n_humans``
        counts only detections labelled ``PLAYER_LABEL`` (the ball is not a
        human), and ``annotated_frames[k]`` is frame ``k`` with boxes drawn.
        Both lists are empty for empty input.
    """
    detection_list = []
    annotated_frames = []

    with torch.no_grad():
        for frame in images:
            # NOTE(review): per the original comment, numpy2tensor converts
            # BGR -> RGB, converts to a tensor and normalizes — confirm in data/augmentation.
            img_tensor = augmentations.numpy2tensor(frame)
            # Add a batch dimension of size 1 and move to the model's device.
            batch = img_tensor.unsqueeze(dim=0).to(device)
            detections = model(batch)[0]

            # Count players only; works whether labels is a list or a tensor.
            n_humans = sum(1 for label in detections['labels'] if label == PLAYER_LABEL)
            detection_list.append((n_humans, detections))

            annotated_frames.append(draw_bboxes(frame, detections))

    return detection_list, annotated_frames

In [None]:
def write_frames(output_handle, frames):
    """Write every element of ``frames`` to ``output_handle``, preserving order.

    ``output_handle`` only needs a ``write(frame)`` method (e.g. ``cv2.VideoWriter``);
    its return value is ignored. Nothing is written when ``frames`` is empty.
    """
    for current_frame in iter(frames):
        output_handle.write(current_frame)

In [None]:
def _read_window(capture, start_frame, length):
    """Read up to ``length`` consecutive frames from ``capture`` starting at ``start_frame``.

    Seeks once and then reads sequentially. Seeking to every frame via
    CAP_PROP_POS_FRAMES makes the decoder restart from a keyframe each time,
    which is slow for compressed containers such as .mov. Failed reads are
    reported and skipped, so the result may be shorter than ``length``.
    """
    capture.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    frames = []
    for _ in range(length):
        ok, frame = capture.read()
        if ok:
            frames.append(frame)
        else:
            print("Error reading frame")
    return frames


def _fit_to_output(frame):
    """Resize ``frame`` to (final_width, final_height) when its size differs.

    cv2.VideoWriter silently drops frames whose size differs from the size it
    was opened with, so mismatched frames would otherwise vanish from the output.
    """
    height, width = frame.shape[:2]
    if (width, height) != (final_width, final_height):
        frame = cv2.resize(frame, (final_width, final_height))
    return frame


def player_detection_video():
    """Write an output video that switches to whichever camera sees more players.

    Frames are processed in consecutive windows of ``floor(final_fps / 2)``
    frames; ``floor(total_frames / window)`` windows are processed and a trailing
    partial window is skipped. For each window, detection runs on both views and
    the annotated frames of the view with the higher human count (ties go to the
    right view) are histogram-equalized and appended to ``video_output``.

    Both captures and the writer are released when processing ends, even if it
    is interrupted by an exception, so the output file is always finalized.
    """
    # max(1, ...) guards against a zero-length window for very low fps settings.
    window_length = max(1, int(math.floor(final_fps / 2)))
    n_windows = math.floor(total_frames / window_length)

    try:
        for window_index in tqdm(range(n_windows)):
            start_frame = window_index * window_length
            left_frames = _read_window(video_left_capture, start_frame, window_length)
            right_frames = _read_window(video_right_capture, start_frame, window_length)

            left_detections, left_annotated_frames = run_detection(left_frames)
            right_detections, right_annotated_frames = run_detection(right_frames)

            left_humans_count = sum(count for count, _ in left_detections)
            right_humans_count = sum(count for count, _ in right_detections)

            print(f"Left humans: {left_humans_count}")
            print(f"Right humans: {right_humans_count}")

            if left_humans_count > right_humans_count:
                chosen_frames = left_annotated_frames
            else:
                chosen_frames = right_annotated_frames

            # NOTE(review): equalization runs after annotation, so the drawn
            # boxes and captions are equalized too — confirm this is intended.
            images_processed = preprocess_images(chosen_frames)
            write_frames(video_output, [_fit_to_output(frame) for frame in images_processed])
    finally:
        video_left_capture.release()
        video_right_capture.release()
        video_output.release()

In [None]:
# Process every window and write the output video. This releases both captures
# and the writer, so re-running it requires re-running the cells that open them.
player_detection_video()