In [None]:
def hex_to_rgb(hex_color):
    """Convert a hex color string into an ``(r, g, b)`` tuple of ints (0-255).

    A leading ``#`` is optional. Both the six-digit form (``"#rrggbb"``) and
    the three-digit shorthand (``"#rgb"``, expanded by doubling each digit)
    are accepted. Characters past the sixth digit are ignored.

    Returns ``(0, 0, 0)`` when the value cannot be parsed: it is not a string
    (for example ``None``), it is too short, or it contains non-hex
    characters. This fallback is deliberate so that a single bad color does
    not abort a whole palette comparison.
    """
    try:
        digits = hex_color.lstrip('#')
        if len(digits) == 3:
            # Shorthand notation: "abc" means "aabbcc".
            digits = ''.join(ch * 2 for ch in digits)
        return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: input is not a str; ValueError: bad or
        # missing hex digits. Anything else (e.g. KeyboardInterrupt) must
        # propagate instead of being silently turned into black.
        return (0, 0, 0)

def euclidean_distance(rgb1, rgb2):
    """Return the straight-line distance between two color tuples.

    Components are paired positionally; if the tuples differ in length the
    surplus components of the longer one are ignored.
    """
    squared_diffs = [(left - right) ** 2 for left, right in zip(rgb1, rgb2)]
    return pow(sum(squared_diffs), 0.5)

def average_distance(palette1, palette2):
    """Return the mean pairwise color distance between two hex palettes.

    Every color of ``palette1`` is compared with every color of ``palette2``
    (after conversion with ``hex_to_rgb``) and the distances are averaged.
    Returns ``float('inf')`` when there are no pairs to compare, i.e. when
    either palette is empty.
    """
    first_rgb = [hex_to_rgb(entry) for entry in palette1]
    second_rgb = [hex_to_rgb(entry) for entry in palette2]

    pair_count = len(first_rgb) * len(second_rgb)
    if pair_count == 0:
        return float('inf')

    running_total = 0
    for left in first_rgb:
        for right in second_rgb:
            running_total += euclidean_distance(left, right)
    return running_total / pair_count

def palette_similarity_score(palette1, palette2):
    """Score how far apart two hex palettes are (lower means more alike).

    Delegates to ``average_distance`` when both palettes are non-empty;
    otherwise returns ``float('inf')`` as a "no comparison possible" marker.
    """
    if palette1 and palette2:
        return average_distance(palette1, palette2)
    # At least one palette is missing or empty: report an unbounded distance.
    return float('inf')

import webcolors
from typing import List

def color_names_to_hex(color_names: List[str]) -> List[str]:
    """Translate color names to hex strings via ``webcolors.name_to_hex``.

    The result has one entry per input name, in the same order. A name for
    which ``webcolors`` raises ``ValueError`` yields ``None`` in its slot.
    """
    def _lookup(name):
        # Unrecognized names are kept as None placeholders so positions line up.
        try:
            return webcolors.name_to_hex(name)
        except ValueError:
            return None

    return [_lookup(name) for name in color_names]


In [None]:
import json
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import euclidean_distances
import matplotlib.pyplot as plt

def _load_json(path):
    """Read the JSON document stored at *path* and return the decoded object."""
    # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of
    # relying on the platform's locale-dependent default.
    with open(path, encoding='utf-8') as f:
        return json.load(f)


# Inputs for the evaluation: ground-truth annotations, model predictions and
# the team roster metadata.
annotations = _load_json('/playpen-storage/levlevi/player-re-id/src/testing/ocr_analysis/_50_game_reid_benchmark_/annotations.json')

predictions = _load_json('/playpen-storage/levlevi/player-re-id/src/testing/ocr_analysis/predictions.json')

rosters = _load_json('/playpen-storage/levlevi/player-re-id/src/data/team_rosters_with_ids_and_colors_and_race_and_team_ids.json')

In [148]:
import random
from typing import List
from collections import Counter

# Running tallies updated by the evaluation loop further down in this cell.
count = 0  # number of (predicted id, ground-truth id) pairs compared
correct = 0  # number of those pairs where the two ids are equal

# 1. narrow down the candidate players

def get_maj_race(races):
    """Return the most frequent value in *races*, or ``None`` if there is none.

    Ties go to the value encountered first. An empty (or ``None``) input
    returns ``None`` rather than raising ``IndexError``, matching
    ``get_maj_jersey_number``.
    """
    ranked = Counter(races).most_common(1)
    return ranked[0][0] if ranked else None

def get_maj_position(positions):
    """Return the most frequent value in *positions*, or ``None`` if there is none.

    Ties go to the value encountered first. An empty (or ``None``) input
    returns ``None`` rather than raising ``IndexError``, matching
    ``get_maj_jersey_number``.
    """
    ranked = Counter(positions).most_common(1)
    return ranked[0][0] if ranked else None

def get_maj_jersey_number(jersey_numbers):
    """Return the jersey number that occurs most often, or ``None`` if the list is empty.

    When several numbers share the top count, the one seen first is returned.
    """
    if len(jersey_numbers) > 0:
        top_entry = Counter(jersey_numbers).most_common(1)[0]
        return top_entry[0]
    return None

def get_team_colors(team_name):
    """Return the ``'colors'`` entry stored for *team_name* in the module-level ``rosters``.

    Raises ``KeyError`` if the team or its ``'colors'`` entry is missing.
    """
    team_entry = rosters[team_name]
    return team_entry['colors']

def get_maj_jersey_colors(color_lists: List[List[str]]):
    """Return up to four of the most frequently reported jersey colors.

    *color_lists* holds one list of color names per observation. All lists
    are pooled and the colors are ranked by how often they occur; ties keep
    first-seen order.

    Entries that are not iterable (for example ``None`` for an observation
    with no color reading) are skipped, so one missing reading no longer
    discards every other vote. If *color_lists* itself is not iterable, an
    empty list is returned.
    """
    try:
        observations = iter(color_lists)
    except TypeError:
        return []

    # Flatten the per-observation lists into one pool of votes.
    all_colors = []
    for sublist in observations:
        try:
            all_colors.extend(sublist)
        except TypeError:
            # Missing / malformed reading for this observation: ignore it.
            continue

    # Rank by frequency and keep the color names of the top four.
    return [color for color, _ in Counter(all_colors).most_common(4)]

def get_most_similar_team_from_colors(team_names, colors):
    """Return the team whose palette is closest to *colors*.

    ``palette_similarity_score`` is a distance (lower means more alike, and
    ``float('inf')`` means "cannot compare"), so the team with the smallest
    score wins. On a tie the team listed first is kept.

    Returns ``None`` when *team_names* is empty or when no team yields a
    finite score (for example because *colors* is empty).
    """
    most_similar_team = None
    best_distance = float('inf')
    for team_name in team_names:
        distance = palette_similarity_score(get_team_colors(team_name), colors)
        # Strict '<' keeps the earliest team on ties and never selects a
        # team whose score is the infinite "no comparison" marker.
        if distance < best_distance:
            best_distance = distance
            most_similar_team = team_name
    return most_similar_team

# Evaluation loop: for every predicted subtrack, pool the players of the two
# rosters named in the subtrack key, keep only those whose number equals the
# majority predicted jersey number, settle on a single candidate, and compare
# that player's id with the human annotation. Updates the module-level
# `count` and `correct` tallies.
candidate_players = {}
for subtrack_dir, features in predictions.items():
    # Team names are read from fixed positions of the '_'-separated key
    # (third-from-last and fifth-from-last pieces).
    # NOTE(review): assumes every key follows that naming layout — confirm.
    t1, t2 = subtrack_dir.split('_')[-3], subtrack_dir.split('_')[-5]
    # Spaces become underscores before the names are used as keys into `rosters`.
    t1, t2 = t1.replace(" ", "_"), t2.replace(" ", "_")
    t1_players = rosters[t1]['players']
    t2_players = rosters[t2]['players']

    # Tag every player record with its team. These are the dicts held inside
    # `rosters`, so the 'team' key is written into `rosters` itself.
    for player in t1_players:
        t1_players[player]['team'] = t1
    for player in t2_players:
        t2_players[player]['team'] = t2

    # New merged dict per subtrack; on a duplicate player key the t2 entry wins.
    combined_dict = {**t1_players, **t2_players}
    candidate_players['players'] = combined_dict

    # Not referenced by any active code below.
    t1_colors = get_team_colors(t1)
    t2_colors = get_team_colors(t2)

    # get maj race (only consumed by the disabled race filter below)
    races = [x['race'] for x in features]
    maj_race = get_maj_race(races)

    # get maj position (only consumed by the disabled position filter below)
    positions = [x['position'] for x in features]
    maj_position = get_maj_position(positions)

    # get maj jersey number; falsy readings are left out of the vote
    # NOTE(review): a reading stored as integer 0 would also be dropped here.
    jersey_numbers = [x['jersey_number'] for x in features if x['jersey_number']]
    maj_jersey_number = get_maj_jersey_number(jersey_numbers)

    # Majority jersey colors, converted from color names to hex strings.
    jersey_colors = [x['jersey_colors'] for x in features]
    maj_jersey_colors = color_names_to_hex(get_maj_jersey_colors(jersey_colors))

    # get most similar team (only consumed by the disabled team filter below)
    most_similar_team = get_most_similar_team_from_colors([t1, t2], maj_jersey_colors)

    # remove players of different jersey number
    # (names are collected first because a dict cannot be resized while it is
    # being iterated)
    remove_players = []
    for player_name, values in candidate_players['players'].items():
        if values['number'] != maj_jersey_number:
            remove_players.append(player_name)
    for player_name in remove_players:
        del candidate_players['players'][player_name]

    # --- Disabled filters (team / race / position), kept for reference ---
    # print(f"Number of candidate players: {len(candidate_players['players'])}")
    # remove players of different team
    # print(f"Number of candidate players: {len(candidate_players['players'])}")

    # remove_players = []
    # for player_name, values in candidate_players['players'].items():
    #     if values['team'] != most_similar_team:
    #         remove_players.append(player_name)
    # for player_name in remove_players:
    #     del candidate_players['players'][player_name]

    # print(f"Number of candidate players: {len(candidate_players['players'])}")
    # remove players of different race
    # remove_players = []
    # for player_name, values in candidate_players['players'].items():
    #     if values['race'] != maj_race:
    #         remove_players.append(player_name)
    # for player_name in remove_players:
    #     del candidate_players['players'][player_name]
    # # print(f"Number of candidate players: {len(candidate_players['players'])}")

    # # remove players of different position
    # remove_players = []
    # for player_name, values in candidate_players['players'].items():
    #     print(values['position'], maj_position)
    #     if maj_position not in values['position']:
    #         remove_players.append(player_name)
    # for player_name in remove_players:
    #     del candidate_players['players'][player_name]
    # print(f"Number of candidate players: {len(candidate_players['players'])}")

    # No player matches the jersey number: the subtrack is skipped and does
    # not contribute to `count`.
    if len(candidate_players['players']) == 0:
        continue
    else:

        # print(candidate_players)
        # Several players match: keep one chosen at random.
        # NOTE(review): no seed is set in the visible code, so results can
        # differ between runs.
        if len(candidate_players['players']) > 1:
            random_key = random.choice(list(candidate_players['players'].keys()))
            candidate_players = {'players': {random_key: candidate_players['players'][random_key]}}

        # print(f"Number candidates: {len(candidate_players['players'])}")
        gt_ids = []
        candidate_ids = [candidate_players['players'][pn]['player_id'] for pn in list(candidate_players['players'].keys())]
        # The dict is zipped with its own keys, so `c` and `pn` are the same
        # name; the body appends one ground-truth id per remaining candidate.
        # `team` and `player_id` are assigned but not used afterwards.
        for c, pn in zip(candidate_players['players'], list(candidate_players['players'].keys())):
            team = candidate_players['players'][c]['team']
            player_id = candidate_players['players'][c]['player_id']

            # print(subtrack_dir)
            # Last path component names the track; the directory just above
            # it is the key into `annotations`.
            subtrack = subtrack_dir.split("/")[-1]
            key = "/".join(subtrack_dir.split("/")[:-1])
            key = key.split("/")[-1]
            gt_id = annotations[key]['tracks'][subtrack]['human_annotation']
            gt_ids.append(gt_id)

        # Score the chosen candidate(s) against the annotation.
        for pred, gt in zip(candidate_ids, gt_ids):
            # print(f"{pred}, {gt}")
            count += 1
            if pred == gt:
                correct += 1

print(f"Total canddidates count: {count}")
print (f"Total correct count: {correct}")
# NOTE(review): the denominator is the literal 50, not `count` or
# len(predictions) — presumably the total number of benchmark tracks; confirm.
print(f"Accuracy: {correct/50}")

Total canddidates count: 21
Total correct count: 18
Accuracy: 0.36
