In [90]:
import pandas as pd
from typing import Any, Dict, Optional

df = pd.read_csv("../data/processed/shot_transitions_parsed_charting-m-points-2010s.csv")
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2412 entries, 0 to 2411
Data columns (total 5 columns):
 #   Column               Non-Null Count  Dtype 
---  ------               --------------  ----- 
 0   last_shot_type       2412 non-null   object
 1   last_shot_direction  2412 non-null   object
 2   shot_type            2412 non-null   object
 3   shot_direction       2412 non-null   object
 4   count                2412 non-null   int64 
dtypes: int64(1), object(4)
memory usage: 94.3+ KB


In [69]:
# Regras para remover transições ilegais
# 1. Um "serve" não pode ser seguido de outro serve
# 2. Um "forehand" ou "backhand" não pode ser seguido de um "serve"
# 3. As direções devem estar entre 1 e 3
# 4. Erros não podem ser seguidos por erros
# 5. Um "winner" não pode ser seguido por outro "winner"
# 6. Erros e winners não tem direção definida (usar 'unknown')

possible_stroke_results = ['b', '@', 'f', 'r', '#', 'i', 'm', 'o',
       'winner', 's', 'v', 'z', 'u', 'h', 'l', 'j', 'y', 't', 'k', 'p',
       'q','unknown']

possible_types = ['serve', 'b', '@', 'f', 'r', '#', 'i', 'm', 'o',
       'winner', 's', 'v', 'z', 'u', 'h', 'l', 'j', 'y', 't', 'k', 'p',
       'q','unknown']

erros = ["#", "@"]

possible_directions = ['unknown','1','2','3','0']

In [70]:
from scipy.special import softmax
counts_lookup = {}

# Para cada tipo/direção possível, garantir que exista uma entrada no dicionário
for s_type in possible_types:
    for s_dir in possible_directions:
        # Garantir que todas as combinações de last_s_type e last_s_dir existam
        for last_s_type in possible_types:
            for last_s_dir in possible_directions:
                if last_s_type == "serve" and s_type == "serve":
                    continue  # Regra 1
                if last_s_type not in erros and s_type == "serve":
                    continue  # Regra 2
                if last_s_type in erros + ["winner"]:
                    last_s_dir = 'unknown'  # Corrigir direção ilegal
                if s_type in erros + ["winner"]:
                    s_dir = 'unknown'  # Corrigir direção ilegal

                key = (last_s_type, last_s_dir, s_type, s_dir)
                if key not in counts_lookup:
                    counts_lookup[key] = 0

In [None]:
# Build quick lookup of counts: map (last_type,last_dir,shot_type,shot_dir) -> count
# Replace direction_4,5,6 with 1,2,3 respectively for serves
for _, row in df.iterrows():
    if row["last_shot_type"] == "serve" and row["last_shot_direction"] in ["direction_4", "direction_5", "direction_6"]:
        last_dir = str(int(row["last_shot_direction"].split("_")[1]) - 3)
    else:
        last_dir = row["last_shot_direction"]
    if row["shot_type"] == "serve" and row["shot_direction"] in ["direction_4", "direction_5", "direction_6"]:
        direction = str(int(row["shot_direction"].split("_")[1]) - 3)
    else:
        direction = row["shot_direction"]

    if direction not in ["1", "2", "3", "unknown"] or last_dir not in ["1", "2", "3", "unknown"]:
        continue
    key = (row["last_shot_type"], last_dir, row["shot_type"], direction)
    counts_lookup[key] = int(row["count"])

In [83]:
normal_target_pairs = [(t, d) for t in possible_stroke_results for d in possible_directions]

error_target_pairs = [(t, d) for t in ["serve"] for d in possible_directions]

for idx, pair in enumerate(normal_target_pairs):
    if pair[0] in erros + ["winner"] and pair[1] != 'unknown':
        normal_target_pairs[idx] = (pair[0], 'unknown')

normal_target_pairs = list(set(normal_target_pairs))

In [97]:
import numpy as np

graph: Dict[str, Dict[str, Dict[str, Dict[str, float]]]] = {}
records: list[dict[str, Any]] = []
temperature = 1.0
for src_type in possible_types:
    graph[src_type] = {}
    for src_dir in possible_directions:

        counts = [counts_lookup.get((src_type, src_dir, t, d), 0) for (t, d) in normal_target_pairs]
        if src_type in erros + ["winner"]:
            src_dir = 'unknown'  # Corrigir direção ilegal
            counts = [counts_lookup.get((src_type, src_dir, t, d), 0) for (t, d) in error_target_pairs]

        # gather counts in consistent order
        # apply softmax (with temperature)
        if temperature != 1.0:
            logits = np.array(counts) / float(temperature)
        else:
            logits = np.array(counts, dtype=float)
        probs = softmax(logits) if logits.size > 0 else np.array([])

        graph[src_type][src_dir] = {}
        for (t, d), c, p in zip(normal_target_pairs if src_type not in erros + ["winner"] else error_target_pairs, counts, probs):
            graph[src_type][src_dir][(t,d)] = float(p)

            records.append({
                "last_shot_type": src_type,
                "last_shot_direction": src_dir,
                "shot_type": t,
                "shot_direction": d,
                "count": int(c),
                "probability": float(p)
            })


df_out = pd.DataFrame.from_records(records)
df_out.to_csv("../data/processed/shot_transition_graph_charting-m-points-2010s.csv", index=False)
print("Source:", src_type, src_dir)
print(counts)


Source: unknown 0
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [99]:
graph['f']['1']

{('f', 'unknown'): 0.01020408163265306,
 ('#', 'unknown'): 0.01020408163265306,
 ('h', 'unknown'): 0.01020408163265306,
 ('b', '0'): 0.01020408163265306,
 ('q', 'unknown'): 0.01020408163265306,
 ('l', '3'): 0.01020408163265306,
 ('l', '2'): 0.01020408163265306,
 ('l', '0'): 0.01020408163265306,
 ('k', '1'): 0.01020408163265306,
 ('f', '1'): 0.01020408163265306,
 ('p', '3'): 0.01020408163265306,
 ('t', '1'): 0.01020408163265306,
 ('p', '2'): 0.01020408163265306,
 ('s', '0'): 0.01020408163265306,
 ('b', 'unknown'): 0.01020408163265306,
 ('p', '0'): 0.01020408163265306,
 ('h', '1'): 0.01020408163265306,
 ('j', '3'): 0.01020408163265306,
 ('j', '2'): 0.01020408163265306,
 ('q', '1'): 0.01020408163265306,
 ('l', 'unknown'): 0.01020408163265306,
 ('j', '0'): 0.01020408163265306,
 ('unknown', '3'): 0.01020408163265306,
 ('unknown', '2'): 0.01020408163265306,
 ('i', '0'): 0.01020408163265306,
 ('unknown', '0'): 0.01020408163265306,
 ('o', '0'): 0.01020408163265306,
 ('p', 'unknown'): 0.0102040