In [1]:
from make_dataset import MultiMatchSoccerDataset

In [2]:
# Directory handed to the dataset loader as ``data_root``.
DATA_ROOT = "match_data"
dataset = MultiMatchSoccerDataset(data_root=DATA_ROOT)

Loading Matches: 100%|██████████| 6/6 [00:44<00:00,  7.39s/it]


In [3]:
# Pull one sample out of the dataset and display its heterogeneous graph.
SAMPLE_IDX = 156
sample = dataset[SAMPLE_IDX]
sample["graph"]

HeteroData(
  Node={ x=[2875, 10] },
  (Node, attk_and_attk, Node)={
    edge_index=[2, 7782],
    edge_attr=[7782, 1],
  },
  (Node, attk_and_def, Node)={
    edge_index=[2, 13848],
    edge_attr=[13848, 1],
  },
  (Node, def_and_def, Node)={
    edge_index=[2, 3682],
    edge_attr=[3682, 1],
  },
  (Node, attk_and_ball, Node)={
    edge_index=[2, 1362],
    edge_attr=[1362, 1],
  },
  (Node, def_and_ball, Node)={
    edge_index=[2, 1252],
    edge_attr=[1252, 1],
  },
  (Node, temporal, Node)={
    edge_index=[2, 2852],
    edge_attr=[2852, 1],
  }
)

In [4]:
import torch


In [10]:
from multiprocessing import Pool
from tqdm.auto import tqdm
import torch

# 1) 검사 함수 정의
def _check_sample(idx):
    sample = dataset[idx]
    data   = sample["graph"]

    bad = []
    # 1) Node.x 검사
    x = data["Node"].x
    n_nan = torch.isnan(x).sum().item()
    n_inf = torch.isinf(x).sum().item()
    if n_nan or n_inf:
        bad.append(f"Node.x: NaNs={n_nan}, Infs={n_inf}")

    # 2) condition, other, target 검사
    for name in ("condition", "other", "target"):
        t = sample[name]
        n_nan = torch.isnan(t).sum().item()
        n_inf = torch.isinf(t).sum().item()
        if n_nan or n_inf:
            bad.append(f"{name}: NaNs={n_nan}, Infs={n_inf}")

    # 3) 각 엣지 타입별 edge_attr 검사
    for edge_type in data.edge_types:
        store = data[edge_type]
        e = store.edge_attr
        n_nan = torch.isnan(e).sum().item()
        n_inf = torch.isinf(e).sum().item()
        if n_nan or n_inf:
            src, rel, dst = edge_type
            bad.append(f"edge_attr ({src}-{rel}-{dst}): NaNs={n_nan}, Infs={n_inf}")

    return (idx, bad) if bad else None

# 2) Run the check in parallel
def parallel_check(start_idx, count, n_workers=4, chunksize=1, check_fn=None):
    """Run a per-sample check over indices ``[start_idx, start_idx + count)``.

    Args:
        start_idx: First dataset index to check.
        count: Number of consecutive indices to check.
        n_workers: Number of worker processes. ``n_workers <= 1`` runs the
            check in the current process (no ``Pool``), which gives readable
            tracebacks and avoids multiprocessing entirely.
        chunksize: Indices dispatched to a worker per task, forwarded to
            ``Pool.imap``; larger values reduce IPC overhead on long runs.
            The default of 1 matches ``Pool.imap``'s own default.
        check_fn: Callable ``idx -> (idx, problems) | None``. Defaults to
            ``_check_sample``. Must be picklable when ``n_workers > 1``.

    Returns:
        List of the non-``None`` results of ``check_fn``, in index order.
    """
    if check_fn is None:
        check_fn = _check_sample
    idxs = range(start_idx, start_idx + count)
    progress = dict(total=len(idxs), desc="Checking samples")

    if n_workers <= 1:
        return [r for r in tqdm(map(check_fn, idxs), **progress) if r is not None]

    with Pool(n_workers) as pool:
        results = pool.imap(check_fn, idxs, chunksize=chunksize)
        # The comprehension drains the iterator before the pool is torn down.
        return [r for r in tqdm(results, **progress) if r is not None]

# 3) Scan the whole dataset
bad_samples = parallel_check(start_idx=0, count=len(dataset), n_workers=7)

# 4) Report: every offending field per sample, or a single all-clear line
if bad_samples:
    for sample_idx, problems in bad_samples:
        print(f"[Sample {sample_idx}] 문제가 있는 필드:")
        for problem in problems:
            print("  ", problem)
else:
    print("모든 샘플에 NaN/Inf 없음")


Checking samples:   0%|          | 0/7015 [00:00<?, ?it/s]

모든 샘플에 NaN/Inf 없음


In [None]:
# Pull every node's feature vector out of the sample graph
node_feats = sample["graph"]["Node"].x    # [total_nodes, feat_dim]

# Locate one node by (frame, slot): node_idx = frame_idx * nodes_per_frame + player_idx.
# NOTE(review): assumes each frame's nodes are stored contiguously in a fixed
# order (11 attackers, 11 defenders, 1 ball) — confirm against make_dataset.
nodes_per_frame = 23   # 11 attackers + 11 defenders + 1 ball
frame_idx = 10         # 11th frame (0-based)
player_idx = 9         # 10th node within that frame (0-based)

# Guard the index arithmetic: an out-of-range player_idx would otherwise
# silently land on a node in a different frame, and negatives would wrap.
n_frames, leftover = divmod(node_feats.shape[0], nodes_per_frame)
assert leftover == 0, f"{node_feats.shape[0]} nodes is not a multiple of {nodes_per_frame}"
assert 0 <= frame_idx < n_frames, f"frame_idx={frame_idx} outside [0, {n_frames})"
assert 0 <= player_idx < nodes_per_frame, f"player_idx={player_idx} outside [0, {nodes_per_frame})"
node_idx = frame_idx * nodes_per_frame + player_idx

# Show just that node's features
single = node_feats[node_idx]              # [feat_dim]
single


tensor([ 1.2305e-01,  2.2853e-01,  1.7007e-02, -4.4118e-02,  6.7407e+01,
         1.1000e+01,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00])


In [None]:
# Dimensions of the sample's condition tensor
sample["condition"].size()

torch.Size([125, 202])

In [None]:
# Peek at the first nine condition column names
sample["condition_columns"][0:9]

['Away_22_x',
 'Away_22_y',
 'Away_22_vx',
 'Away_22_vy',
 'Away_22_dist',
 'Away_22_position',
 'Away_22_starter',
 'Away_22_possession_duration',
 'Away_22_neighbor_count']

In [None]:
# Distinct player prefixes (e.g. "Away_22", "Home_2") taken from the condition
# column names: everything before the second underscore, ball columns skipped,
# first-seen order preserved. dict.fromkeys de-duplicates in O(n) while keeping
# insertion order, and the generator keeps its scratch names out of the
# notebook namespace.
bases = list(dict.fromkeys(
    "_".join(col.split("_", 2)[:2])
    for col in sample["condition_columns"]
    if not col.startswith("ball_")
))

In [None]:
# Display the unique non-ball column prefixes in first-seen order
bases

['Away_22',
 'Away_25',
 'Away_26',
 'Away_29',
 'Away_31',
 'Away_33',
 'Away_34',
 'Away_35',
 'Away_36',
 'Away_37',
 'Away_39',
 'Home_2',
 'Home_3',
 'Home_5',
 'Home_6',
 'Home_8',
 'Home_9',
 'Home_11',
 'Home_12',
 'Home_13',
 'Home_14',
 'Home_17']