# Computer Vision Project 2: Snooker Ball Detection & Tracking

### Author: Matei Bejan, group 407

## Preliminaries

Import necessary libraries. Install opencv 4.2 if needed.

In [None]:
import numpy as np
import cv2 as cv 
import os
import glob
import itertools
from random import randrange

In [None]:
# !pip install opencv-python==4.2.0.32

Define the global BGR and HSV colour-range (mask) dictionaries. They are used throughout Tasks 1, 2 and 3.

In [None]:
# Per-colour lower/upper bounds, looked up as 'low <colour>' / 'high <colour>'.
# `maskset` is applied with cv.inRange directly to BGR images, `maskset_hsv`
# to images that were first converted with cv.COLOR_BGR2HSV.
maskset = {
    bound + ' ' + colour: np.array(limits)
    for colour, (lower, upper) in {
        'red':    ([0, 0, 26],     [26, 42, 255]),
        'blue':   ([91, 0, 0],     [180, 135, 55]),
        'pink':   ([49, 0, 187],   [255, 162, 255]),
        'green':  ([37, 23, 0],    [84, 255, 30]),
        'brown':  ([0, 52, 37],    [45, 105, 166]),
        'yellow': ([0, 107, 68],   [114, 255, 255]),
        'black':  ([0, 0, 0],      [45, 38, 65]),
        'white':  ([80, 140, 65],  [255, 255, 255]),
    }.items()
    for bound, limits in (('low', lower), ('high', upper))
}

# 'bg' is the range used for the green-ball search in detect_balls_hsv; red
# needs two ranges ('red 1' and 'red 2'), one at each end of the hue axis.
maskset_hsv = {
    bound + ' ' + colour: np.array(limits)
    for colour, (lower, upper) in {
        'yellow': ([20, 100, 100],  [30, 255, 255]),
        'bg':     ([65, 64, 68],    [108, 255, 255]),
        'blue':   ([100, 150, 0],   [140, 255, 255]),
        'red 1':  ([0, 120, 70],    [10, 255, 255]),
        'red 2':  ([170, 120, 70],  [180, 255, 255]),
    }.items()
    for bound, limits in (('low', lower), ('high', upper))
}

def path_index(path):
    """Return the integer file stem of *path* (``'test_data/Task1/12.jpg'`` -> ``12``).

    Uses ``os.path.basename`` instead of splitting on ``'/'`` so the key also
    works when glob returns paths with the platform's native separator.
    """
    return int(os.path.basename(path).split('.')[0])


def bbox_file_index(path):
    """Return ``(first, last)`` integer fields of an underscore-separated file stem.

    For example ``'test_data/Task3/3_2.txt'`` -> ``(3, 2)``.
    """
    fields = os.path.basename(path).split('.')[0].split('_')
    return int(fields[0]), int(fields[-1])


# Input files for each task, sorted numerically by file name (a plain string
# sort would put '10' before '2').
task1_image_paths = sorted(glob.glob('test_data/Task1/*.jpg'), key = path_index)
# task1_ground_truth_paths = sorted(glob.glob('training_data/Task1/*.txt'), 
#                                   key = lambda x: int(x.split('/')[-1].split('.')[0]))

task2_video_paths = sorted(glob.glob('test_data/Task2/*.mp4'), key = path_index)
# task2_ground_truth_paths = sorted(glob.glob('training_data/Task2/*.txt'), 
#                                   key = lambda x: int(x.split('/')[-1].split('_')[0]))

task3_video_paths = sorted(glob.glob('test_data/Task3/*.mp4'), key = path_index)

# Sorted by the first, then the last underscore-separated number of the stem.
task3_initial_bboxes = sorted(glob.glob('test_data/Task3/*.txt'), key = bbox_file_index)

task4_video_paths = sorted(glob.glob('test_data/Task4/*.mp4'), key = path_index)

## Define utility functions

Most functions defined in this section are shared by all tasks. Utility functions used by only one task are defined in that task's own section.

In [None]:
def find_table(frame):
    """Black out everything in a BGR frame that is not table-green.

    The frame is converted to HSV, pixels inside the fixed green range
    (46, 100, 0)-(85, 255, 255) are selected, and the selection is closed
    (dilated, then eroded) with a 15x15 kernel before being used as a mask.

    Returns a copy of ``frame`` in which pixels outside the mask are zero.
    """
    hsv = cv.cvtColor(frame, cv.COLOR_BGR2HSV)
    green_mask = cv.inRange(hsv, (46, 100, 0), (85, 255, 255))

    # Closing: two dilations followed by two erosions with the same kernel,
    # which fills gaps inside the selected region.
    structuring_element = np.ones((15, 15), np.uint8)
    grown = cv.dilate(green_mask, structuring_element, iterations=2)
    green_mask = cv.erode(grown, structuring_element, iterations=2)

    return cv.bitwise_and(frame, frame, mask=green_mask)

def crop_table(image, scale_factor = .955):
    """Crop ``image`` to the (slightly shrunk) outline of the table.

    The largest contour of the green region returned by ``find_table`` is
    scaled about its centroid by ``scale_factor``; pixels outside the scaled
    outline are blacked out and the result is cropped to the outline's
    bounding rectangle.

    Parameters:
        image: BGR image.
        scale_factor: factor applied to the contour around its centroid.

    Returns:
        The masked image cropped to the scaled contour's bounding box.

    Raises:
        ValueError: if no contour is found in the green mask.
    """
    table = find_table(image)

    gray = cv.cvtColor(table, cv.COLOR_BGR2GRAY)

    contours, _ = cv.findContours(gray,
                                  cv.RETR_TREE,
                                  cv.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        # Same exception type max() raised on an empty sequence, clearer text.
        raise ValueError('crop_table: no table-coloured region found in the image')

    c = max(contours, key = cv.contourArea)

    M = cv.moments(c)
    if M['m00'] != 0:
        cx, cy = int(M['m10'] / M['m00']), int(M['m01'] / M['m00'])
    else:
        # Degenerate (zero-area) contour: the centroid is undefined, so scale
        # about the centre of the bounding box instead of dividing by zero.
        bx, by, bw, bh = cv.boundingRect(c)
        cx, cy = bx + bw // 2, by + bh // 2

    # Scale the contour about the chosen centre: translate to the origin,
    # scale, translate back.
    cnt_scaled = (c - [cx, cy]) * scale_factor + [cx, cy]
    c = cnt_scaled.astype(np.int32)

    # White-filled polygon on black: AND-ing keeps only pixels inside it.
    stencil = np.zeros(image.shape).astype(image.dtype)
    cv.fillPoly(stencil, [c], (255, 255, 255))
    result = cv.bitwise_and(image, stencil)

    x, y, w, h = cv.boundingRect(c)

    return result[y:y + h, x:x + w]

def process_crop(table_crop):
    """Trim rows from the top and bottom of a table crop.

    The crop is binarised (any non-black pixel counts) and the number of
    non-zero pixels per row is compared between adjacent rows, scanning from
    the middle row outwards. A row is a cut candidate when that count changes
    by more than 20 from the previously scanned row and:

    * upwards: the row index is > 50 and more than 100 rows above the middle;
    * downwards: the row is fewer than 100 rows from the bottom.

    The crop is cut at the topmost upward candidate and at the first downward
    candidate; if a direction has no candidate that side is left untouched.
    (The numeric thresholds look tuned to the project's footage - TODO confirm.)

    Returns the row-sliced view ``table_crop[a:b, :]``.
    """
    img_gray = cv.cvtColor(table_crop, cv.COLOR_BGR2GRAY)
    _, img_bw = cv.threshold(img_gray, 0, 100, cv.THRESH_BINARY)

    h = img_bw.shape[0]
    height_mid = h // 2

    # Non-zero pixel count of every row, computed once.
    row_fill = np.count_nonzero(img_bw, axis=1)

    # Scan upwards from the middle; each row is compared with the one below.
    ups = []
    for i in range(height_mid - 1, 0, -1):
        jump = abs(int(row_fill[i + 1]) - int(row_fill[i]))
        if jump > 20 and i > 50 and height_mid - i > 100:
            ups.append(i)

    # Scan downwards, starting afresh from the middle row. Previously the
    # first comparison of this scan used the last row visited by the upward
    # scan (near the top of the image) as its "previous" row.
    downs = []
    for i in range(height_mid + 1, h):
        jump = abs(int(row_fill[i - 1]) - int(row_fill[i]))
        if jump > 20 and h - i < 100:
            downs.append(i)

    a = ups[-1] if ups else 0      # last appended == smallest row index
    b = downs[0] if downs else h

    return table_crop[a:b,:]

In [None]:
def check_used(colours, used_colours):
    """Return True when every item of ``colours`` is contained in ``used_colours``.

    An empty ``colours`` yields True.
    """
    return all(candidate in used_colours for candidate in colours)

def sliding_window(image, stepSize, windowSize):
    """Yield ``(x, y, window)`` for windows taken every ``stepSize`` pixels.

    ``windowSize`` is ``(width, height)``. Windows are produced row by row;
    those that touch the right or bottom edge are clipped by slicing, so the
    caller has to check ``window.shape`` if it needs full-size windows.
    """
    n_rows, n_cols = image.shape[0], image.shape[1]
    for top in range(0, n_rows, stepSize):
        bottom = top + windowSize[1]
        for left in range(0, n_cols, stepSize):
            patch = image[top:bottom, left:left + windowSize[0]]
            yield (left, top, patch)

def check_color(table):
    """Rank the eight ball colours by how much of ``table`` falls in their BGR range.

    For each colour the image is masked with the corresponding ``maskset``
    bounds and the mean of that mask is taken as the score.

    Returns the colour names sorted by descending score; colours with equal
    scores keep the order white, blue, pink, red, green, yellow, brown, black.
    """
    colours = ['white', 'blue', 'pink', 'red', 'green', 'yellow', 'brown', 'black']

    means = []
    for colour in colours:
        # cv.inRange already yields a 0/255 mask, so its mean is the score.
        # The masked image, extra threshold and contour search that used to
        # be computed here never influenced the result.
        mask = cv.inRange(table, maskset['low ' + colour], maskset['high ' + colour])
        means.append((colour, mask.mean()))

    return [elem[0] for elem in sorted(means, key = lambda x: x[1], reverse = True)]

def search_for_balls(image, threshold = 45):
    """Find small bright blobs with an 8x8 sliding window.

    The image is converted to grayscale and thresholded at 180. A window
    (stride 2) is accepted when its one-pixel border is completely dark and
    the mean of its interior is at least ``threshold``. A window is dropped
    when an already accepted window starts within half a window before, or
    less than a full window after, its own origin on both axes.

    Parameters:
        image: BGR image.
        threshold: minimum mean value of the window interior (0-255 scale).

    Returns:
        List of ``(x, y, w, h)`` tuples sorted by descending ``y``.
    """
    img_gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY)
    _, img_bw = cv.threshold(img_gray, 180, 255, cv.THRESH_BINARY)

    win_w, win_h = 8, 8
    balls = []

    for (x, y, window) in sliding_window(img_bw, 2, (win_w, win_h)):
        # Windows clipped by the image edge are ignored.
        if window.shape[0] != win_h or window.shape[1] != win_w:
            continue

        # Border rows/columns are addressed relative to the window, so they
        # stay correct if the window size is changed.
        border_lit = (window[0, :].any() or window[:, 0].any() or
                      window[-1, :].any() or window[:, -1].any())
        if border_lit or window[1:-1, 1:-1].mean() < threshold:
            continue

        # Plain integer comparisons instead of building a list per test.
        overlaps = any(x - win_w // 2 <= bx < x + win_w and
                       y - win_h // 2 <= by < y + win_h
                       for bx, by, _, _ in balls)
        if not overlaps:
            balls.append((x, y, win_w, win_h))

    return sorted(balls, key = lambda ball: ball[1], reverse = True)

def detect_balls(table, colour, annots = None):
    """Detect balls of ``colour`` in a BGR image using the ``maskset`` ranges.

    Contours of the colour mask are kept when their area is > 25 and their
    bounding box is narrower and lower than 25 px. If ``annots`` is given,
    contours whose bounding-box origin lies within 15 px (on both axes) of an
    existing annotation are discarded.

    Parameters:
        table: BGR image.
        colour: key suffix into ``maskset`` (e.g. ``'pink'``).
        annots: optional iterable of ``(colour, x, y, ...)`` annotations.

    Returns:
        The return shape depends on the outcome (kept for compatibility):

        * nothing found (any colour): ``[[], [], colour]``;
        * non-red colour: ``[contour, (x, y, w, h), colour]`` for the largest
          remaining contour;
        * ``'red'``: a list of ``('red', x, y, w, h)`` tuples for contours
          whose area is strictly between 100 and 170 (possibly empty).
    """
    annots = [] if annots is None else annots

    mask = cv.inRange(table, maskset['low ' + colour], maskset['high ' + colour])
    ret, thresh = cv.threshold(mask, 100, 255, 0)
    contours, hierarchy = cv.findContours(thresh, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_NONE)

    keep = []
    for contour in contours:
        x, y, w, h = cv.boundingRect(contour)
        if not (cv.contourArea(contour) > 25 and w < 25 and h < 25):
            continue
        already_annotated = any(x in range(annot[1] - 15, annot[1] + 15) and
                                y in range(annot[2] - 15, annot[2] + 15)
                                for annot in annots)
        if not already_annotated:
            keep.append(contour)

    if len(keep) == 0:
        return [[], [], colour]

    if colour != 'red':
        c = max(keep, key = cv.contourArea)
        return [c, cv.boundingRect(c), colour]

    result = []
    for contour in keep:
        if 100 < cv.contourArea(contour) < 170:
            x, y, w, h = cv.boundingRect(contour)
            result.append((colour, x, y, w, h))
    return result

def _hsv_mask_to_gray(frame_hsv, mask):
    """Apply ``mask`` to an HSV frame and return the result as a grayscale image."""
    kept_hsv = cv.bitwise_and(frame_hsv, frame_hsv, mask = mask)
    kept_bgr = cv.cvtColor(kept_hsv, cv.COLOR_HSV2BGR)
    return cv.cvtColor(kept_bgr, cv.COLOR_BGR2GRAY)


def detect_balls_hsv(image, colour):
    """Detect balls in a BGR image using the HSV ranges in ``maskset_hsv``.

    Parameters:
        image: BGR image.
        colour: ``'red'``, ``'bg'`` or any other key suffix of ``maskset_hsv``
            (an unknown colour raises ``KeyError``).

    Returns:
        The shape depends on ``colour``:

        * ``'red'``: list of ``(contour, (x, y, w, h))`` for every contour of
          the combined red mask (after thresholding at 100 and an 11x11
          Gaussian blur); empty list when nothing is found.
        * ``'bg'``: ``[contour1, bbox1, contour2, bbox2]`` for the two largest
          contours with area < 500 and ``h - 5 <= w < h + 5``, or
          ``[[], [], [], []]`` when fewer than two qualify.
        * anything else: ``[contour, (x, y, w, h)]`` for the largest contour,
          or ``[[], []]``. For ``'blue'`` the box is extended one pixel upwards.
    """
    frame_hsv = cv.cvtColor(image, cv.COLOR_BGR2HSV)

    if colour == 'red':
        mask_red1 = cv.inRange(frame_hsv, maskset_hsv['low red 1'], maskset_hsv['high red 1'])
        mask_red2 = cv.inRange(frame_hsv, maskset_hsv['low red 2'], maskset_hsv['high red 2'])
        # OR the two hue ranges; uint8 '+' would wrap where both masks are set.
        mask_red = cv.bitwise_or(mask_red1, mask_red2)

        gray = _hsv_mask_to_gray(frame_hsv, mask_red)
        gray = cv.threshold(gray, 100, 255, cv.THRESH_BINARY)[1]
        gray = cv.GaussianBlur(gray, (11, 11), 0)

        contours, hierarchy = cv.findContours(gray, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_NONE)
        return [(contour, cv.boundingRect(contour)) for contour in contours]

    mask_balls = cv.inRange(frame_hsv, maskset_hsv['low ' + colour], maskset_hsv['high ' + colour])
    balls_out_gray = _hsv_mask_to_gray(frame_hsv, mask_balls)
    contours, hierarchy = cv.findContours(balls_out_gray, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_NONE)

    if colour == 'bg':
        # Keep small, roughly square contours only.
        candidates = []
        for contour in contours:
            x, y, w, h = cv.boundingRect(contour)
            if cv.contourArea(contour) < 500 and w in range(h - 5, h + 5):
                candidates.append(contour)

        if len(candidates) < 2:
            return [[], [], [], []]

        candidates = sorted(candidates, key = cv.contourArea, reverse = True)
        ball1, ball2 = cv.boundingRect(candidates[0]), cv.boundingRect(candidates[1])
        return [candidates[0], ball1, candidates[1], ball2]

    if len(contours) == 0:
        return [[], []]

    c = max(contours, key = cv.contourArea)
    x, y, w, h = cv.boundingRect(c)
    if colour == 'blue':
        y = y - 1
        h = h + 1

    return [c, (x, y, w, h)]
    
def annotate_balls(table, annotations, annotate_reds = False):
    """Draw a white box (and optionally a label) for every annotation, in place.

    Parameters:
        table: image to draw on; it is modified and also returned.
        annotations: iterable of ``(colour, x, y, w, h)`` tuples.
        annotate_reds: when True, red balls are labelled ``'Red <n>'`` with a
            running number; otherwise reds only get a box.

    Non-red balls are always labelled with their colour name, first letter
    upper-cased.
    """
    white = (255, 255, 255)
    # fontFace, fontScale, colour, thickness - shared by every label.
    text_style = (cv.FONT_HERSHEY_SIMPLEX, .5, white, 1)
    reds_seen = 0

    for colour, x, y, w, h in annotations:
        cv.rectangle(table, (x, y), (x + w, y + h), white, 1)

        if colour == 'red':
            reds_seen += 1
            if not annotate_reds:
                continue
            label = 'Red ' + str(reds_seen)
        else:
            label = colour[0].upper() + colour[1:]

        cv.putText(table, label, (x - 20, y - 10), *text_style)

    return table

def get_total_count(colour_count):
    """Sum the values of ``colour_count``, leaving out the ``'white'`` entry."""
    return sum(amount for name, amount in colour_count.items() if name != 'white')

## Task 1 (96% accuracy): Single-frame ball detection & count

In [None]:
def snooker_ball_detector(image, 
                          possible_colours = ['red', 'white', 'blue', 'pink', 
                                              'brown', 'green', 'yellow', 'black'],
                         flag = 0):
    """Detect and count the snooker balls visible in one frame.

    Parameters:
        image: BGR frame.
        possible_colours: colours that may be counted; colours not listed are
            never added to the result. The default list is only read, never
            modified.
        flag: pre-processing selector -
            0: ``process_crop(crop_table(image, .965))``;
            1: ``crop_table(image, .97)``;
            2: ``process_crop(crop_table(image, .97))``;
            any other value: ``image`` is used as given (and is then drawn on
            in place by ``annotate_balls``).

    Returns:
        ``(colour_count, annotated_image, annotation_boxes)`` where
        ``colour_count`` maps each possible colour to the number of balls
        found, ``annotated_image`` is the pre-processed image with boxes and
        labels drawn on it, and ``annotation_boxes`` is a list of
        ``(colour, x, y, w, h)`` tuples.

    NOTE(review): boxes appended for yellow/green during the preliminary pass
    are not removed when the counts are reset, so ``annotation_boxes`` can
    hold more boxes than ``colour_count`` reports - confirm this is intended.
    """
    if flag == 1:
        image = crop_table(image, .97)
    elif flag == 2:
        image = process_crop(crop_table(image, .97))
    elif flag == 0:
        image = process_crop(crop_table(image, .965))
    
    # image_mask_black: detected balls are painted black so that later
    # detectors do not pick them up again.
    # image_mask_white: detected balls are painted white; it is only used for
    # the final black-ball search.
    image_mask_black = image.copy()
    image_mask_white = image.copy()
    
    colour_count = {}
    for colour in possible_colours:
        if colour not in colour_count:
            colour_count[colour] = 0
            
    annotation_boxes, preemptive_detection, whites = [], [], []
    
    # Colours removed from the candidate ranking in the final red-mask pass.
    no_business_in_the_red = ['black', 'blue', 'brown', 'green']
    
    # --- Preliminary pass: its counts are reset below, but the rectangles
    # painted on image_mask_white and the yellow/green boxes are kept.
    for colour in ['pink', 'white', 'yellow', 'blue', 'brown']:
        if colour in colour_count and colour_count[colour] == 0:
            det_contour, det_bbox, det_colour = detect_balls(image_mask_black, colour)
            if len(det_contour) != 0:
                x, y, w, h = det_bbox
                if colour in colour_count:
                    colour_count[colour] += 1
                    if colour == 'white':
                        cv.rectangle(image_mask_black, (x - 2, y - 2), (x + w + 4, y + h + 4), (0, 0, 0), -1)
                        cv.rectangle(image_mask_white, (x - 5, y - 2), (x + w + 5, y + h + 5), (255,255,255), -1)
                    elif colour == 'pink':
                        cv.rectangle(image_mask_black, (x - 4, y - 2), (x + w + 2, y + h), (0, 0, 0), -1)
                        cv.rectangle(image_mask_white, (x - 5, y - 2), (x + w + 5, y + h + 5), (255,255,255), -1)
                    else:
                        cv.rectangle(image_mask_black, (x - 2, y - 2), (x + w + 1, y + h + 1), (0, 0, 0), -1)
                        cv.rectangle(image_mask_white, (x - 5, y - 2), (x + w + 5, y + h + 5), (255,255,255), -1)
    
    if 'yellow' in colour_count and colour_count['yellow'] == 0:
        yellow_contour, yellow_bbox = detect_balls_hsv(image_mask_black, 'yellow')
        # len() instead of `!= []`: comparing an ndarray contour with a list
        # is an error on recent NumPy versions.
        if len(yellow_contour) != 0:
            x, y, w, h = yellow_bbox
            colour_count['yellow'] += 1
            annotation_boxes.append(('yellow', x, y, w, h))
            cv.rectangle(image_mask_black, (x, y), (x + w, y + h), (0, 0, 0), -1)
            
    if 'green' in colour_count and colour_count['green'] == 0:
        c1, b1, c2, b2 = detect_balls_hsv(image_mask_black, 'bg')
        if len(b1) != 0 and cv.contourArea(c1) > 15:
            x, y, w, h = b1
            if cv.contourArea(c1) in list(range(20, 25)):
                x, y, w, h = x - 1, y - 1, w + 1, h + 1
            colour_count['green'] += 1
            annotation_boxes.append(('green', x, y, w, h))
            cv.rectangle(image_mask_black, (x, y), (x + w, y + h), (0, 0, 0), -1)

    # Reset the preliminary counts and start again from a clean working copy.
    for key in colour_count.keys():
        if colour_count[key] == 1:
            preemptive_detection.append(key)
            colour_count[key] -= 1
            
    image_mask_black = image.copy()
            
    # --- Bright-spot pass: classify each candidate window as red or not.
    balls, red_balls = search_for_balls(image_mask_black), []
    for ball in balls:
        x, y, w, h = ball
        curr_colour, colour_hierarchy = None, check_color(image_mask_black[y:y + h + 5, x:x + w + 5])
        red_idx = np.where(np.array(colour_hierarchy) == 'red')[0][0]
        if (check_used(colour_hierarchy[:red_idx], ['pink', 'white', 'yellow'])) or \
            (red_idx < 4 and check_used(colour_hierarchy[:red_idx], ['brown', 'white', 'yellow']) and \
                colour_hierarchy[0] != 'brown') or \
            (red_idx < 3 and check_used(colour_hierarchy[:red_idx], ['brown'])):
            curr_colour = 'red'
        else:
            curr_colour = colour_hierarchy[0]
        if curr_colour in colour_count and curr_colour == 'red':
            colour_count[curr_colour] += 1
            red_balls.append((x, y, w, h))
    
    for red_ball in red_balls:
        x, y, w, h = red_ball
        annotation_boxes.append(('red', x, y, w, h))
        cv.rectangle(image_mask_black, (x - 1, y), (x + w + 1, y + h), (0, 0, 0), -1)
        whites.append((x - 5, y - 2, x + w + 5, y + h + 15))
        
    # Remaining bright spots: accept at most one brown and one black.
    for ball in balls:
        if ball not in red_balls:
            x, y, w, h = ball
            curr_colour, colour_hierarchy = None, check_color(image_mask_black[y:y + h + 5, x:x + w + 5])
            curr_colour = colour_hierarchy[0]
            if curr_colour in colour_count and curr_colour in ['brown', 'black'] \
                and colour_count[curr_colour] == 0:
                colour_count[curr_colour] += 1
                annotation_boxes.append((curr_colour, x, y, w, h))
                cv.rectangle(image_mask_black, (x, y), (x + w + 5, y + h + 5), (0, 0, 0), -1)
                whites.append((x - 5, y - 2, x + w + 5, y + h + 5))
        
    # --- BGR-range pass for the single-ball colours still missing.
    for colour in ['pink', 'white', 'yellow', 'blue', 'brown']:
        if colour in colour_count and colour_count[colour] == 0 and colour != 'black':
            det_contour, det_bbox, det_colour = detect_balls(image_mask_black, colour)
            if len(det_contour) != 0:
                x, y, w, h = det_bbox
                colour_count[colour] += 1
                annotation_boxes.append((colour, x, y, w, h))
                if colour == 'white':
                    cv.rectangle(image_mask_black, (x - 2, y - 2), (x + w + 4, y + h + 4), (0, 0, 0), -1)
                    whites.append((x - 2, y - 2, x + w + 4, y + h + 4))
                elif colour == 'pink':
                    cv.rectangle(image_mask_black, (x - 4, y - 2), (x + w + 2, y + h), (0, 0, 0), -1)
                    whites.append((x - 4, y - 2, x + w + 2, y + h))
                else:
                    cv.rectangle(image_mask_black, (x - 2, y - 2), (x + w + 1, y + h + 1), (0, 0, 0), -1)
                    whites.append((x - 2, y - 2, x + w + 1, y + h + 1))
            
    # --- HSV fallbacks for yellow and green, then a BGR fallback for green.
    if 'yellow' in colour_count and colour_count['yellow'] == 0:
        yellow_contour, yellow_bbox = detect_balls_hsv(image_mask_black, 'yellow')
        if len(yellow_contour) != 0 and cv.contourArea(yellow_contour) > 10:
            x, y, w, h = yellow_bbox
            colour_count['yellow'] += 1
            annotation_boxes.append(('yellow', x, y, w, h))
            cv.rectangle(image_mask_black, (x, y), (x + w, y + h), (0, 0, 0), -1)
            whites.append((x, y, x + w, y + h))
            
    if 'green' in colour_count and colour_count['green'] == 0:
        c1, b1, c2, b2 = detect_balls_hsv(image_mask_black, 'bg')
        if len(b1) != 0 and cv.contourArea(c1) > 15:
            x, y, w, h = b1
            if cv.contourArea(c1) in list(range(20, 25)):
                x, y, w, h = x - 1, y - 1, w + 1, h + 1
            colour_count['green'] += 1
            annotation_boxes.append(('green', x, y, w, h))
            cv.rectangle(image_mask_black, (x, y), (x + w, y + h), (0, 0, 0), -1)
            whites.append((x, y, x + w, y + h))
    
    if 'green' in colour_count and colour_count['green'] == 0:
        green_contour, green_bbox, _ = detect_balls(image_mask_black, 'green')
        if len(green_contour) != 0:
            x, y, w, h = green_bbox
            colour_count['green'] += 1
            annotation_boxes.append(('green', x, y, w, h))
            cv.rectangle(image_mask_black, (x, y), (x + w, y + h), (0, 0, 0), -1)
            whites.append((x, y, x + w, y + h))
        
    # --- HSV red-mask pass over what is left un-blackened.
    red_data = detect_balls_hsv(image_mask_black, 'red')

    for elem in red_data:
        contour, bbox = elem[0], elem[1]
        x, y, w, h = bbox
        thresh = cv.threshold(cv.cvtColor(image_mask_black[y:y+h, x:x+w], cv.COLOR_BGR2GRAY), 150, 255, 0)[1]
        if thresh.mean() > 4:
            colour_hierarchy = check_color(image[y:y+h, x:x+w])
            for key in no_business_in_the_red:
                colour_hierarchy.remove(key)
            # Single-ball colours that were already found cannot occur again.
            for key in list(colour_count.keys()):
                if key != 'red' and colour_count[key] == 1 and key in colour_hierarchy:
                    colour_hierarchy.remove(key)
            if colour_hierarchy[0] in colour_count:
                colour_count[colour_hierarchy[0]] += 1
                annotation_boxes.append((colour_hierarchy[0], x, y, w, h))
                cv.rectangle(image_mask_black, (x, y), (x + w, y + h), (0, 0, 0), -1)
                if cv.contourArea(contour) > 250:
                    cv.fillPoly(image_mask_white, pts = [contour], color = (255, 255, 255))
                else:
                    whites.append((x - 5, y - 2, x + w + 5, y + h + 15))

    # Paint every ball found so far white before looking for the black ball.
    for white in whites:
        x1, y1, x2, y2 = white
        cv.rectangle(image_mask_white, (x1, y1), (x2, y2), (255, 255, 255), -1)
                
    if 'black' in colour_count and colour_count['black'] == 0:
        det_contours, det_bbox, _ = detect_balls(image_mask_white, 'black')
        if len(det_contours) != 0:
            x, y, w, h = det_bbox
            colour_count['black'] += 1
            annotation_boxes.append(('black', x, y, w, h))
            cv.rectangle(image_mask_black, (x - 2, y - 2), (x + w + 1, y + h + 1), (0, 0, 0), -1)
            cv.rectangle(image_mask_white, (x - 2, y - 2), (x + w + 1, y + h + 1), (255,255,255), -1)

    # An undetected white ball is still counted as present. The membership
    # guard avoids a KeyError when 'white' is not among possible_colours.
    if 'white' in colour_count and colour_count['white'] == 0:
        colour_count['white'] += 1
    
    annotated_image = annotate_balls(image, annotation_boxes)
    
    return colour_count, annotated_image, annotation_boxes

In [None]:
# Run the Task 1 detector on every test image and write one result file per
# image to output/Task1/<n>.txt, where <n> is the 1-based position of the
# image in task1_image_paths. File layout: the total on the first line, then
# one "<count> <colour>" line per colour.
file_count = 0

for file_count, image_path in enumerate(task1_image_paths, start=1):
    image = cv.imread(image_path)
    data = snooker_ball_detector(image)
    colour_count = data[0]

    total = 0
    to_write = {'white': 0, 'black': 0, 'pink': 0, 'blue': 0, 'green': 0, 'brown': 0, 'yellow': 0, 'red': 0}
    for key in colour_count:
        total += colour_count[key]
        to_write[key] = colour_count[key]

    # A single 'w' open truncates and writes; the context manager closes the
    # file even if a write fails.
    with open('output/Task1/' + str(file_count) + '.txt', 'w') as file:
        file.write(str(total) + '\n')
        for key in to_write:
            file.write(str(to_write[key]) + ' ' + str(key) + '\n')

## Task 2: Potted ball detection

In [None]:
def read_frames(video_path):
    """Read every frame of a video into memory.

    Parameters:
        video_path: path passed to ``cv.VideoCapture``.

    Returns:
        List of frames in order. If the video cannot be opened a message is
        printed and an empty list is returned.
    """
    frames = []
    cap = cv.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print("Error opening video stream or file")
            return frames

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Release the capture on every path, including the early return and
        # errors raised while reading.
        cap.release()
    return frames


def split_red_nonred(bboxes):
    """Partition boxes by whether their first element is ``'red'``.

    Empty boxes are skipped. Returns ``(red_boxes, other_boxes)``, each in the
    original order.
    """
    reds, others = [], []

    for box in bboxes:
        if len(box) == 0:
            continue
        target = others if box[0] != 'red' else reds
        target.append(box)

    return reds, others

def get_bbox_of_colour(bboxes, colour):
    """Return the first box whose first element equals ``colour``, or ``[]`` if none does."""
    return next((candidate for candidate in bboxes if candidate[0] == colour), [])

def position_of_point(A, B, P):
    """Classify point ``P`` relative to the directed line from ``A`` to ``B``.

    The sign of the 2-D cross product ``AB x AP`` decides the result:
    positive -> ``'right'``, negative -> ``'left'``, zero -> ``'on'``.
    """
    ab_x, ab_y = B[0] - A[0], B[1] - A[1]
    ap_x, ap_y = P[0] - A[0], P[1] - A[1]

    cross = ab_x * ap_y - ab_y * ap_x

    if cross > 0:
        return 'right'
    return 'left' if cross < 0 else 'on'

def get_bboxes_of_interest(bboxes, colour):
    """Select the boxes of one colour from a list of box groups.

    ``bboxes`` is a list whose elements are themselves iterables of boxes
    (presumably one group per frame - verify against the caller). A box
    matches when it is non-empty and its first element equals ``colour``.

    For ``'red'`` the grouping is preserved: one list of matching boxes per
    group, possibly empty. For any other colour the matches of all groups are
    returned as one flat list.
    """
    def matches(box):
        return len(box) > 0 and box[0] == colour

    if colour == 'red':
        return [[box for box in group if matches(box)] for group in bboxes]

    return [box for group in bboxes for box in group if matches(box)]

def draw_lines(frame):
    """Draw six guide lines on ``frame`` in place and return it.

    Four horizontal lines (at 1/9 of the height, at height // 1.15, and
    1/9 of the height above and below height // 2.5) and two slanted
    lines near the left and right sides, using the same fractions that
    ``check_pot`` tests against.
    """
    rows, cols = frame.shape[0], frame.shape[1]

    top_y = int(rows // 9)
    bottom_y = int(rows // 1.15)
    middle_y = int(rows // 2.5)

    # (start, end, BGR colour) of every guide line, in drawing order.
    segments = [
        ((0, top_y), (cols, top_y), (0, 0, 255)),
        ((0, bottom_y), (cols, bottom_y), (255, 0, 0)),
        ((0, middle_y - top_y), (cols, middle_y - top_y), (0, 255, 0)),
        ((0, middle_y + top_y), (cols, middle_y + top_y), (0, 255, 0)),
        ((int(cols // 3.8), 0), (int(cols // 5.8), rows), (255, 255, 255)),
        ((int(cols // 1.3), 0), (int(cols // 1.1), rows), (255, 255, 255)),
    ]

    for start, end, colour in segments:
        cv.line(frame, start, end, colour, 2)

    return frame

def check_pot(table, ball_coords):
    """Map a ball position to a pocket number, or 0 when it is near none.

    The table image is divided into three horizontal bands (top: above 1/9 of
    the height; middle: 1/9 of the height around height // 2.5; bottom:
    below height // 1.15). Within a band the ball must lie on the 'left' of
    the right-hand guide line or on the 'right' of the left-hand guide line,
    as reported by ``position_of_point``.

    Returns:
        2 / 1 for the top band, 6 / 5 for the middle band, 4 / 3 for the
        bottom band (right-hand line / left-hand line), otherwise 0.

    NOTE(review): the left-hand line here runs from width // 5.8 at the top to
    width // 3.8 at the bottom, whereas draw_lines draws it from width // 3.8
    to width // 5.8 - confirm which orientation is intended.
    """
    x, y = ball_coords
    rows, cols = table.shape[0], table.shape[1]

    middle_y, margin = int(rows // 2.5), int(rows // 9)

    beside_right_line = position_of_point((int(cols // 1.3), 0),
                                          (int(cols // 1.1), rows),
                                          (x, y)) == 'left'
    beside_left_line = position_of_point((int(cols // 5.8), 0),
                                         (int(cols // 3.8), rows),
                                         (x, y)) == 'right'

    # (ball is in band, pocket at the right-hand line, pocket at the left-hand line)
    bands = (
        (y < rows // 9, 2, 1),
        (y in range(middle_y - margin, middle_y + margin), 6, 5),
        (y > int(rows // 1.15), 4, 3),
    )

    for in_band, right_pocket, left_pocket in bands:
        if in_band and beside_right_line:
            return right_pocket
        if in_band and beside_left_line:
            return left_pocket

    return 0

def pot_heuristic(white_ball_pos, coloured_ball_pos, colour, table):
    """Guess the pocket number from where the white ball sits relative to a coloured ball.

    Parameters:
        white_ball_pos: ``(x, y)`` of the white ball.
        coloured_ball_pos: ``(x, y)`` of the other ball.
        colour: colour of the other ball.
        table: object with a ``shape`` attribute; only read when ``colour``
            is ``'red'``.

    Returns:
        A pocket number between 1 and 6, or ``None`` when the colour is
        unknown or no rule applies (for example when the two balls share an
        x or y coordinate in a rule that needs both to differ).
    """
    x_white, y_white = white_ball_pos
    x_col, y_col = coloured_ball_pos

    # Position of the white ball relative to the coloured ball on each axis:
    # -1 = smaller coordinate, +1 = larger coordinate, 0 = equal.
    dy = int(y_white > y_col) - int(y_white < y_col)
    dx = int(x_white > x_col) - int(x_white < x_col)
    side = (dy, dx)

    baulk_colours = {(-1, -1): 6, (-1, 1): 5, (1, -1): 2, (1, 1): 1}
    by_colour = {
        'black': {(-1, 1): 3, (-1, -1): 4, (1, 1): 5, (1, -1): 6},
        'blue': {(-1, -1): 4, (-1, 1): 3, (1, -1): 2, (1, 1): 1},
        'pink': {(-1, -1): 4, (-1, 1): 3, (1, -1): 6, (1, 1): 5},
        'yellow': baulk_colours,
        'brown': baulk_colours,
        'green': baulk_colours,
    }

    if colour in by_colour:
        return by_colour[colour].get(side)

    if colour != 'red':
        return None

    # Reds: the rule depends on which horizontal band of the table the ball is in.
    middle_y, margin = int(table.shape[0] // 2.5), int(table.shape[0] // 9)
    band_top, band_bottom = middle_y - margin, middle_y + margin

    if y_col < band_top:
        return {(1, 1): 1, (1, -1): 2, (-1, -1): 6, (-1, 1): 5}.get(side)
    if y_col in range(band_top, band_bottom):
        return 5 if x_col < x_white else 6
    if y_col > band_bottom:
        return {(-1, 1): 3, (-1, -1): 4, (1, -1): 1, (1, 1): 2}.get(side)
    return None

In [None]:
def potted_ball_detection(video_path):
    """Decide whether a ball was potted in a video and, if so, which one and where.

    Every usable frame is run through ``snooker_ball_detector``; the per-colour
    counts of the second recorded frame are compared with those of the last
    recorded frame to decide whether (and which colour of) ball disappeared.
    The pocket is then looked up with ``check_pot`` and, when that returns 0,
    with ``pot_heuristic``.

    Parameters:
        video_path: path handed to ``read_frames``.

    Returns:
        A tuple ``(was_ball_potted, pot_position, potted_ball_colour)``.
        ``was_ball_potted`` is ``'YES'`` or ``'NO'``; the other two entries are
        ``None`` when it is ``'NO'``. When no pocket could be determined a
        random pocket index from 1 to 6 is returned.

    Raises:
        IndexError: if fewer than two frames were usable (the comparison
            baseline is the second recorded frame).
    """
    frames = read_frames(video_path)
    reference_frame = process_crop(crop_table(frames[0]))

    all_colours = ['red', 'white', 'blue', 'pink', 'brown', 'green', 'yellow', 'black']
    # ccs holds one (colour_count, bboxes) pair per usable frame.
    colours_vid, ccs = [], []

    for frame in frames:
        ct = crop_table(frame, .97)
        pct = process_crop(ct)
        # When process_crop removed more than 150 rows, the plain crop is used
        # instead. `flag` tells the detector which of the two it received.
        if ct.shape[0] - pct.shape[0] > 150:
            processed_frame, flag = ct, 1
        else:
            processed_frame, flag = pct, 2

        # Skip frames whose crop has no grayscale pixel above 100.
        thresh = cv.threshold(cv.cvtColor(processed_frame, cv.COLOR_BGR2GRAY), 100, 255, 0)[1]
        if thresh.mean() == 0:
            continue

        if not ccs:
            # First usable frame: look for every colour and remember which
            # colours are present so later frames only search for those.
            colour_count, _, frame_bboxes = snooker_ball_detector(processed_frame,
                                                                  all_colours,
                                                                  flag)
            frame_bboxes = list(set(frame_bboxes))
            ccs.append((colour_count, frame_bboxes))
            colours_vid = [key for key in colour_count if colour_count[key] != 0]
            continue

        colour_count, _, frame_bboxes = snooker_ball_detector(processed_frame,
                                                              colours_vid,
                                                              flag)
        frame_bboxes = list(set(frame_bboxes))

        _, bboxes_last_nonred = split_red_nonred(ccs[-1][1])
        frame_bboxes_red, frame_bboxes_nonred = split_red_nonred(frame_bboxes)

        frame_bboxes_new = []
        for colour in ['black', 'blue', 'green', 'pink', 'yellow', 'white', 'brown']:
            bbox_curr = get_bbox_of_colour(frame_bboxes_nonred, colour)

            if colour == 'white':
                frame_bboxes_new.append(bbox_curr)
                continue

            bbox_last = get_bbox_of_colour(bboxes_last_nonred, colour)
            if bbox_curr != [] and bbox_last != []:
                # A coloured ball that moved 100 px or more on both axes since
                # the previous frame keeps its previous position.
                if bbox_curr[1] not in range(bbox_last[1] - 100, bbox_last[1] + 100) and \
                        bbox_curr[2] not in range(bbox_last[2] - 100, bbox_last[2] + 100):
                    frame_bboxes_new.append((bbox_curr[0], bbox_last[1],
                                             bbox_last[2], bbox_curr[3], bbox_curr[4]))
                else:
                    frame_bboxes_new.append(bbox_curr)
            elif bbox_last == []:
                frame_bboxes_new.append(bbox_curr)

        ccs.append((colour_count, frame_bboxes_new + frame_bboxes_red))

    # NOTE(review): the original also contained a search for a later baseline
    # frame, but it was guarded by `first_cc == -1` while first_cc was always 1,
    # so it could never run. The baseline therefore stays the second frame.
    first_cc = 1

    potted, colour = False, None
    if get_total_count(ccs[first_cc][0]) != get_total_count(ccs[-1][0]):
        for key1, key2 in zip(sorted(ccs[first_cc][0].keys()),
                              sorted(ccs[-1][0].keys())):
            if ccs[first_cc][0][key1] != ccs[-1][0][key2]:
                potted = True
                colour = key1

    if not potted:
        return ('NO', None, None)

    bboxes = [cc[1] for cc in ccs]
    bboxes_of_interest = get_bboxes_of_interest(bboxes, colour)
    pot_position = None

    if colour != 'red':
        pot = check_pot(reference_frame, (bboxes_of_interest[-1][1], bboxes_of_interest[-1][2]))

        if pot == 0:
            white_ball_pos, coloured_ball_pos = [], []
            for ball in ccs[0][1]:
                if ball[0] == 'white':
                    white_ball_pos = (ball[1], ball[2])
                if ball[0] == colour:
                    coloured_ball_pos = (ball[1], ball[2])

            pot = pot_heuristic(white_ball_pos, coloured_ball_pos, colour, reference_frame)

        pot_position = pot
    else:
        pot = None
        for it in range(len(bboxes_of_interest) - 50):
            bboxes_curr = bboxes_of_interest[it]
            bboxes_next = bboxes_of_interest[it + 1]
            bboxes_next_50 = bboxes_of_interest[it + 50]
            # A red ball counts as gone when the number of red boxes drops in
            # the next frame and is still lower 50 frames later.
            if len(bboxes_curr) > len(bboxes_next) and len(bboxes_curr) > len(bboxes_next_50):
                for bbox in bboxes_next:
                    pot = check_pot(reference_frame, (bbox[1], bbox[2]))
                    if pot != 0:
                        pot_position = pot

        if pot == 0:
            # Take the white ball position from the earliest frame that has one.
            # The original loop condition was inverted (`while pos != []` with
            # pos == []), so this fallback never ran.
            white_ball_pos = []
            for _, frame_balls in ccs:
                for ball in frame_balls:
                    # Entries may be empty placeholders for undetected balls.
                    if len(ball) > 0 and ball[0] == 'white':
                        white_ball_pos = (ball[1], ball[2])
                if white_ball_pos != []:
                    break

            if white_ball_pos != []:
                red_first = bboxes_of_interest[1]
                red_last = bboxes_of_interest[-1]
                possibly_potted = list(set(red_first) ^ set(red_last))

                # Nothing to feed the heuristic when both frames hold the same boxes.
                if possibly_potted:
                    pot = pot_heuristic(white_ball_pos,
                                        (possibly_potted[0][1], possibly_potted[0][2]),
                                        'red', reference_frame)

                    # Do not let an undecided heuristic erase a pocket found above.
                    if pot is not None and pot != 0:
                        pot_position = pot

    if pot_position is None:
        pot_position = randrange(1, 7)

    return ('YES', pot_position, colour)

In [None]:
# Run potted-ball detection on every Task 2 video and write one result file
# per video, numbered from 1 in the order of task2_video_paths.
file_count = 0

for video_path in task2_video_paths:
    data = potted_ball_detection(video_path)

    file_count += 1

    # A single handle opened in 'w' mode truncates any previous result and is
    # closed even if a write fails. The original reopened the file several
    # times and only closed the last handle.
    with open('output/Task2/' + str(file_count) + '.txt', 'w') as file:
        file.write(str(data[0]) + '\n')

        # Pocket and colour are only written when a ball was potted.
        if data[0] == 'YES':
            file.write(str(data[1]) + '\n')
            file.write(str(data[2]))

## Task 3: Single-view ball tracking (cue & coloured)

In [None]:
def get_crop_table_coordinates(image, scale_factor = .955):
    """Return the bounding rectangle of the table contour, shrunk about its centroid.

    The largest contour of ``find_table(image)`` is scaled by ``scale_factor``
    around its centroid and the bounding rectangle of the scaled contour is
    returned.

    Parameters:
        image: frame handed to ``find_table``.
        scale_factor: factor applied to the contour points relative to the
            contour centroid.

    Returns:
        ``(x, y, w, h)`` of the scaled contour. If the contour has zero area
        (no centroid can be computed) the unscaled bounding rectangle is
        returned instead.

    Raises:
        ValueError: if no contour is found (``max`` of an empty sequence).
    """
    table = find_table(image)

    gray = cv.cvtColor(table, cv.COLOR_BGR2GRAY)

    contours, _ = cv.findContours(gray,
                                  cv.RETR_TREE,
                                  cv.CHAIN_APPROX_SIMPLE)

    c = max(contours, key = cv.contourArea)

    M = cv.moments(c)
    if M['m00'] == 0:
        # Degenerate contour: the centroid division below would fail.
        x, y, w, h = cv.boundingRect(c)
        return x, y, w, h

    cx, cy = int(M['m10'] / M['m00']), int(M['m01'] / M['m00'])

    # Scale the contour points around the centroid, then truncate to int32.
    cnt_scaled = (c - [cx, cy]) * scale_factor + [cx, cy]
    c = cnt_scaled.astype(np.int32)

    # The original also built a full-size mask and a masked copy of the image
    # here, but never used either, so that work has been dropped.
    x, y, w, h = cv.boundingRect(c)

    return x, y, w, h

def is_within_frame(object_bbox, frame_bbox):
    """Return True when the object box lies strictly inside the frame box.

    Both boxes are ``(x, y, w, h)``. A box that touches any edge of the frame
    is not considered inside.
    """
    obj_left, obj_top, obj_w, obj_h = object_bbox
    frm_left, frm_top, frm_w, frm_h = frame_bbox

    fits_horizontally = obj_left > frm_left and obj_left + obj_w < frm_left + frm_w
    fits_vertically = obj_top > frm_top and obj_top + obj_h < frm_top + frm_h

    return bool(fits_horizontally and fits_vertically)

def point_is_within_frame(object_bbox, frame_bbox):
    """Return True when the point lies strictly inside the frame box.

    ``object_bbox`` is an ``(x, y)`` point and ``frame_bbox`` is
    ``(x, y, w, h)``. Points on the border are not considered inside.
    """
    px, py = object_bbox
    left, top, width, height = frame_bbox

    inside_columns = px > left and px < left + width
    inside_rows = py > top and py < top + height

    return bool(inside_columns and inside_rows)

def origins_are_equal(point1, point2):
    """Return True when two 4-element boxes share the same first two entries.

    Only the leading ``(x, y)`` pair of each box is compared; the remaining
    two entries are ignored.
    """
    first_x, first_y = point1[0], point1[1]
    _, _, _, _ = point1  # keep the original requirement of exactly 4 entries
    second_x, second_y, _, _ = point2

    return bool(first_x == second_x and first_y == second_y)

def get_elem_at_index(idx, bboxes):
    """Return the payload of the first ``(index, payload)`` pair whose index is ``idx``.

    Returns -1 when no pair in ``bboxes`` carries that index.
    """
    matches = (entry[1] for entry in bboxes if entry[0] == idx)
    return next(matches, -1)

In [None]:
def _near_previous_bbox(bbox, prev_bbox):
    """Tell whether a corner of `bbox` lies within 50 px around `prev_bbox`.

    Both boxes are (x, y, w, h). Only the top-left and bottom-right corners
    of `bbox` are tested.
    """
    x, y, w, h = bbox
    x_prev, y_prev, w_prev, h_prev = prev_bbox
    vicinity = (x_prev - 50, y_prev - 50, w_prev + 100, h_prev + 100)
    return point_is_within_frame((x, y), vicinity) or \
        point_is_within_frame((x + w, y + h), vicinity)


def _annotate_white_ball(frame, ball_bbox, prev_bbox):
    """Draw the accepted ball box and the previous box's search area on `frame` in place."""
    x, y, w, h = ball_bbox
    x_prev, y_prev, w_prev, h_prev = prev_bbox
    cv.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)
    cv.rectangle(frame,
                 (x_prev - 50, y_prev - 50),
                 (x_prev + w_prev + 50, y_prev + h_prev + 50),
                 (255, 255, 255), 1)


def _moving_white_blob(frame, background, gray, table_bbox, prev_bbox):
    """Look for a white, ball-sized blob in the difference to the background frame.

    Returns the (x, y, w, h) box of the first acceptable contour (largest
    area first), or None.
    """
    x_prev, y_prev, w_prev, h_prev = prev_bbox
    delta = cv.absdiff(background, gray)
    mask = cv.threshold(delta, 50, 255, cv.THRESH_BINARY)[1]
    mask = cv.dilate(mask, None, iterations=2)
    cnts = cv.findContours(mask.copy(), cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)[0]

    for c in sorted(cnts, key = cv.contourArea, reverse = True):
        c_bbox = cv.boundingRect(c)
        x, y, w, h = c_bbox
        # NOTE(review): range(x - 25, x - 25) is empty, so this test is always
        # False and this pass never accepts a contour. It is kept unchanged
        # because the intended bounds cannot be determined from the code
        # (possibly a width/height similarity check) -- confirm and fix.
        if not (is_within_frame(c_bbox, table_bbox) and
                50 < cv.contourArea(c) < 500 and
                y in range(x - 25, x - 25) and
                check_color(frame[y:y + h, x:x + w])[0] == 'white'):
            continue

        if is_within_frame(c_bbox, (x_prev - 50, y_prev - 50,
                                    w_prev + 130, h_prev + 100)) or \
                _near_previous_bbox(c_bbox, prev_bbox):
            return c_bbox

    return None


def _bright_white_blob(frame, mask, table_bbox, prev_bbox):
    """Look for a white, ball-sized blob in a binary brightness mask.

    Returns the (x, y, w, h) box of the first contour (largest area first)
    that is on the table, near the previous box, has an area between 50 and
    500, is at most 30x30 px and is classified 'white' by check_color;
    otherwise None.
    """
    cnts = cv.findContours(mask, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)[0]

    for c in sorted(cnts, key = cv.contourArea, reverse = True):
        bbox = cv.boundingRect(c)
        x, y, w, h = bbox
        if is_within_frame(bbox, table_bbox) and \
                _near_previous_bbox(bbox, prev_bbox) and \
                50 < cv.contourArea(c) < 500 and \
                w <= 30 and h <= 30 and \
                check_color(frame[y:y + h, x:x + w])[0] == 'white':
            return bbox

    return None


def track_white_ball(video_path, initial_bbox_path):
    """Track the white ball through a video, starting from an annotated box.

    The first usable frame becomes the background reference and is assigned
    the box read from the annotation file. For every later usable frame the
    following passes are tried in order until one yields a box:

    1. a white blob in the difference to the background frame;
    2. a white blob in the dilated brightness mask (gray > 200);
    3. a white blob in the eroded brightness mask;
    4. the previous box itself, if it is still mostly bright.

    Parameters:
        video_path: path handed to ``read_frames``.
        initial_bbox_path: text file whose second-to-last line holds
            ``<field> x1 y1 x2 y2`` separated by single spaces.

    Returns:
        A list of ``(frame_index, (x1, y1, x2, y2))``. Frames without a
        detection have no entry. The initial box is always stored under
        frame index 0.
    """
    frames = read_frames(video_path)

    # Read the annotation up front and release the handle immediately.
    with open(initial_bbox_path) as file:
        lines = file.readlines()

    # `background` stays None until the first usable frame. A None sentinel
    # avoids comparing an ndarray with a list, which recent NumPy rejects.
    background, white_bboxes, curr_bbox = None, [], None

    for count, frame in enumerate(frames):
        ct = crop_table(frame, .97)
        pct = process_crop(ct)
        processed_frame = ct if ct.shape[0] - pct.shape[0] > 150 else pct

        # Skip frames whose crop has no grayscale pixel above 100.
        table_mask = cv.threshold(cv.cvtColor(processed_frame, cv.COLOR_BGR2GRAY), 100, 255, 0)[1]
        if table_mask.mean() == 0:
            continue

        table_bbox = get_crop_table_coordinates(frame)

        gray = cv.cvtColor(frame, cv.COLOR_BGR2GRAY)
        gray = cv.GaussianBlur(gray, (21, 21), 0)

        if background is None:
            background = gray
            # Corner coordinates in the file are converted to (x, y, w, h).
            x1, y1, x2, y2 = (int(field) for field in lines[-2].split(' ')[1:5])
            curr_bbox = (x1, y1, x2 - x1, y2 - y1)
            white_bboxes.append((0, curr_bbox))
            continue

        found = _moving_white_blob(frame, background, gray, table_bbox, curr_bbox)

        if found is None:
            bright = cv.threshold(cv.cvtColor(frame, cv.COLOR_BGR2GRAY),
                                  200, 255, cv.THRESH_BINARY)[1]
            dilated = cv.dilate(bright, None, iterations=1)
            found = _bright_white_blob(frame, dilated, table_bbox, curr_bbox)

            if found is None:
                eroded = cv.erode(bright, np.ones((3, 3), np.uint8))
                found = _bright_white_blob(frame, eroded, table_bbox, curr_bbox)

            if found is None:
                # Last resort: keep the previous box if it is still mostly bright.
                x_curr, y_curr, w_curr, h_curr = curr_bbox
                if dilated[y_curr:y_curr + h_curr, x_curr:x_curr + w_curr].mean() > 100:
                    found = curr_bbox

        if found is not None:
            _annotate_white_ball(frame, found, curr_bbox)
            white_bboxes.append((count, found))
            curr_bbox = found

    # Convert (x, y, w, h) to corner format (x1, y1, x2, y2).
    output = []
    for count, bbox in white_bboxes:
        output.append((count, (bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3])))

    return output
        
def track_coloured_ball(video_path, initial_bbox_path):
    """Track a coloured ball with a CSRT tracker, corrected by colour histograms.

    The tracker is initialised on the first frame with the annotated box. For
    each later frame the colour histogram of the tracked box is compared with
    the histogram of the initial box. When the distance exceeds 300, windows of
    the previous box size near the previous position (and on the table) are
    scanned and the one with the smallest histogram distance replaces the
    tracker's box. Finally, the track is cut at the first point where the box
    stops moving while its histogram distance stays above 250.

    Parameters:
        video_path: path handed to ``read_frames``.
        initial_bbox_path: text file whose second-to-last line holds
            ``<field> x1 y1 x2 y2`` separated by single spaces.

    Returns:
        A list of ``(frame_index, (x1, y1, x2, y2))``.
    """
    hist_bins, hist_ranges = [4, 4, 4], [0, 256, 0, 256, 0, 256]

    tracker = cv.TrackerCSRT_create()

    # Read the annotation and release the handle immediately.
    with open(initial_bbox_path) as file:
        lines = file.readlines()
    x1, y1, x2, y2 = (int(field) for field in lines[-2].split(' ')[1:5])
    curr_bbox = (x1, y1, x2 - x1, y2 - y1)

    frames = read_frames(video_path)
    frame = frames[0]
    table_bbox = get_crop_table_coordinates(frame)

    tracker.init(frame, curr_bbox)

    template_hist = cv.calcHist([frame[curr_bbox[1]:curr_bbox[1] + curr_bbox[3],
                                       curr_bbox[0]:curr_bbox[0] + curr_bbox[2]]],
                                [0, 1, 2], None, hist_bins, hist_ranges)

    # Entries are (frame_index, histogram_distance, (x, y, w, h)).
    color_histograms = [(0, 0, curr_bbox)]

    for count, frame in enumerate(frames[1:], start=1):
        prev_bbox = curr_bbox
        last_frame_vicinity = (prev_bbox[0] - 50, prev_bbox[1] - 50,
                               prev_bbox[2] + 100, prev_bbox[3] + 100)

        ok, curr_bbox = tracker.update(frame)
        curr_bbox = (int(curr_bbox[0]), int(curr_bbox[1]), int(curr_bbox[2]), int(curr_bbox[3]))
        curr_hist = cv.calcHist([frame[curr_bbox[1]:curr_bbox[1] + curr_bbox[3],
                                       curr_bbox[0]:curr_bbox[0] + curr_bbox[2]]],
                                [0, 1, 2], None, hist_bins, hist_ranges)

        hist_dist = cv.compareHist(template_hist, curr_hist, cv.HISTCMP_CHISQR_ALT)

        if hist_dist > 300:
            # The tracked patch no longer resembles the ball: search around
            # the previous position with a window of the previous box size.
            win_w, win_h = prev_bbox[2], prev_bbox[3]
            candidates = []
            for (x, y, window) in sliding_window(frame, 2, (win_w, win_h)):
                if window.shape[0] != win_h or window.shape[1] != win_w:
                    continue

                if is_within_frame((x, y, win_w, win_h), table_bbox) and \
                        point_is_within_frame((x, y), last_frame_vicinity):
                    hist = cv.calcHist([window], [0, 1, 2], None, hist_bins, hist_ranges)
                    candidates.append((cv.compareHist(template_hist, hist, cv.HISTCMP_CHISQR_ALT),
                                       (x, y, win_w, win_h)))

            # Keep the tracker's own box when no window qualified; the
            # original indexed an empty list here and raised IndexError.
            if candidates:
                hist_dist, curr_bbox = min(candidates, key = lambda cand: cand[0])

        if ok and frame is not None:
            color_histograms.append((count, hist_dist, curr_bbox))

    # Find the first index where the box stays at the same origin with a
    # distance above 250: entry `it`, entry `it + 1`, and at least three of
    # the four entries that follow.
    idx_to_cut = -1
    for it in range(len(color_histograms) - 1):
        _, dist_here, bbox_here = color_histograms[it]
        _, dist_next, bbox_next = color_histograms[it + 1]

        if not (origins_are_equal(bbox_here, bbox_next) and
                dist_here > 250 and dist_next > 250):
            continue
        if it + 6 >= len(color_histograms):
            continue

        still_life = 1
        for it2 in range(it + 2, it + 6):
            if origins_are_equal(bbox_here, color_histograms[it2][2]) and \
                    color_histograms[it2][1] > 250:
                still_life += 1

        if still_life >= 4:
            idx_to_cut = it
            break

    if idx_to_cut != -1:
        color_histograms = color_histograms[:idx_to_cut]

    # Convert (x, y, w, h) to corner format (x1, y1, x2, y2).
    output = []
    for count, _, curr_bbox in color_histograms:
        output.append((count, (curr_bbox[0], curr_bbox[1],
                               curr_bbox[0] + curr_bbox[2],
                               curr_bbox[1] + curr_bbox[3])))

    return output

In [None]:
# Track the white and the coloured ball in every Task 3 video and write one
# result file per ball. task3_initial_bboxes holds two annotation paths per
# video: white ball first, coloured ball second.
curr_annot = 0
file_count = 0

for video_path in task3_video_paths:
    frames = read_frames(video_path)

    white_bboxes = track_white_ball(video_path, task3_initial_bboxes[curr_annot])
    coloured_bboxes = track_coloured_ball(video_path, task3_initial_bboxes[curr_annot + 1])
    curr_annot += 2

    file_count += 1

    # 'w' truncates previous results; the context manager closes both files
    # even if a write fails.
    with open('output/Task3/' + str(file_count) + '_ball_1.txt', 'w') as file1, \
            open('output/Task3/' + str(file_count) + '_ball_2.txt', 'w') as file2:
        # Header line: frame count followed by four -1 placeholders.
        file1.write(str(len(frames)) + ' -1 -1 -1 -1\n')
        file2.write(str(len(frames)) + ' -1 -1 -1 -1\n')

        for it in range(len(frames)):
            # Frames without a detection (-1) produce no line.
            curr_bbox = get_elem_at_index(it, white_bboxes)
            if curr_bbox != -1:
                file1.write(str(it) + ' ' +
                            str(curr_bbox[0]) + ' ' +
                            str(curr_bbox[1]) + ' ' +
                            str(curr_bbox[2]) + ' ' +
                            str(curr_bbox[3]) + '\n')

            curr_bbox = get_elem_at_index(it, coloured_bboxes)
            if curr_bbox != -1:
                file2.write(str(it) + ' ' +
                            str(curr_bbox[0]) + ' ' +
                            str(curr_bbox[1]) + ' ' +
                            str(curr_bbox[2]) + ' ' +
                            str(curr_bbox[3]) + '\n')

## Task 4: Multiple-view cue ball tracking

In [None]:
def multiple_view_tracking(video_path):
    """Detect the cue ball frame by frame as the largest suitable bright blob.

    For every usable frame the image is thresholded at gray > 230 and eroded;
    the first contour (largest area first) that passes the position and shape
    filters below is recorded. At most one box is recorded per frame.

    Returns:
        A list of ``(frame_index, (x1, y1, x2, y2))``, where the box is the
        contour's bounding rectangle padded by 5 px on the left/top, 6 px on
        the right and 14 px at the bottom.
    """
    detections = []

    for frame_idx, frame in enumerate(read_frames(video_path)):
        cropped = crop_table(frame, .9)
        refined = process_crop(cropped)
        if cropped.shape[0] - refined.shape[0] > 150:
            table_view = cropped
        else:
            table_view = refined

        # Skip frames whose crop has no grayscale pixel above 100.
        table_mask = cv.threshold(cv.cvtColor(table_view, cv.COLOR_BGR2GRAY), 100, 255, 0)[1]
        if table_mask.mean() == 0:
            continue

        table_bbox = get_crop_table_coordinates(frame)

        bright = cv.threshold(cv.cvtColor(frame, cv.COLOR_BGR2GRAY),
                              230, 255, cv.THRESH_BINARY)[1]
        bright = cv.erode(bright, np.ones((5, 5), np.uint8))

        blobs = cv.findContours(bright, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)[0]

        for blob in sorted(blobs, key = cv.contourArea, reverse = True):
            x, y, w, h = cv.boundingRect(blob)

            # The top-left corner must be on the table.
            if not point_is_within_frame((x, y), table_bbox):
                continue

            # Width and height must be within 50 px of each other.
            if w not in range(h - 50, h + 50) or abs(w - h) > 50:
                continue

            # Reject blobs above row 200 or below row 600, and blobs whose
            # position/size product is within 1000 of two hard-coded values.
            # NOTE(review): presumably these constants exclude known static
            # bright spots -- confirm.
            signature = (x + w) * (y * h)
            if y < 200 or abs(159324 - signature) < 1000:
                continue
            if y > 600 or abs(396936 - signature) < 1000:
                continue

            detections.append((frame_idx, (x - 5, y - 5,
                                           x + w + 6, y + h + 14)))
            break

    return detections

In [None]:
# Track the cue ball in every Task 4 video and write one result file per
# video, numbered from 1 in the order of task4_video_paths.
file_count = 0

for video_path in task4_video_paths:
    white_bboxes = multiple_view_tracking(video_path)

    file_count += 1

    frames = read_frames(video_path)

    # 'w' truncates previous results; the context manager closes the file
    # even if a write fails.
    with open('output/Task4/' + str(file_count) + '_ball_1.txt', 'w') as file1:
        # Header line: frame count followed by four -1 placeholders.
        file1.write(str(len(frames)) + ' -1 -1 -1 -1\n')

        for it in range(len(frames)):
            # Frames without a detection (-1) produce no line.
            curr_bbox = get_elem_at_index(it, white_bboxes)
            if curr_bbox != -1:
                file1.write(str(it) + ' ' +
                            str(curr_bbox[0]) + ' ' +
                            str(curr_bbox[1]) + ' ' +
                            str(curr_bbox[2]) + ' ' +
                            str(curr_bbox[3]) + '\n')