In [1]:
from src.simulation.group_phase import simulate_group_phase
from src.simulation.bracket import get_knockout_games
from src.model import GoalSampler
import pandas as pd
from collections import defaultdict
from tqdm import tqdm

group_games = pd.read_csv(r"data\cleaned\2024_games.csv", sep=";")
group_games = group_games[group_games["group"].notna()]
market_values = pd.read_csv(r"data\cleaned\2024_market_values.csv", sep=";", index_col="Country")["MarketValue"]
goal_sampler = GoalSampler()
stages = ["Round of 16", "Quarterfinals", "Semifinals", "Final"]
results = {stage: defaultdict(int) for stage in stages}
results["Winner"] = defaultdict(int)


def simulate_tournament(group_games: pd.DataFrame, market_values: pd.Series) -> str:
    global results
    group_results = simulate_group_phase(group_games)
    next_games = get_knockout_games(group_results)

    for stage in stages:
        for team1, team2 in next_games:
            results[stage][team1] += 1
            results[stage][team2] += 1

        round_winners = [
            goal_sampler.get_knockout_stage_winner(team1, team2, market_values[team1], market_values[team2]) for team1, team2 in next_games
        ]
        if len(round_winners) == 1:
            results["Winner"][round_winners[0]] += 1
            return
        next_games = [round_winners[i:i + 2] for i in range(0, len(round_winners), 2)]

In [2]:
n_runs = 100_000
pbar = tqdm(total=n_runs)

for simulation_run in range(n_runs):
    simulate_tournament(group_games, market_values)
    pbar.update(1)

100%|█████████▉| 999/1000 [00:29<00:00, 33.46it/s]

In [3]:
results["Winner"]

defaultdict(int,
            {'Netherlands': 80,
             'Portugal': 107,
             'England': 271,
             'France': 170,
             'Belgium': 40,
             'Ukraine': 19,
             'Germany': 65,
             'Spain': 80,
             'Turkey': 16,
             'Italy': 52,
             'Scotland': 8,
             'Slovakia': 10,
             'Georgia': 4,
             'Switzerland': 5,
             'Czech Republic': 8,
             'Hungary': 4,
             'Denmark': 22,
             'Serbia': 12,
             'Croatia': 9,
             'Romania': 5,
             'Slovenia': 4,
             'Austria': 2,
             'Poland': 4,
             'Albania': 3})

100%|██████████| 1000/1000 [00:40<00:00, 33.46it/s]