In [2]:
from modules.data_utils import load_all_slices_from_tree, SliceDataset2p5DMasked, _compute_runs
from sklearn.model_selection import GroupShuffleSplit
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset

def split_subjectwise_indices_ordered(subject_id, test_ratio=0.3, val_ratio=0.1, seed=42):
    """Split sample indices into train/val/test so that no subject spans two splits.

    Subjects are listed in order of first appearance, shuffled with a seeded
    ``np.random.RandomState``, and then assigned: the first
    ``round(n * test_ratio)`` subjects go to test, the next
    ``round((n - n_test) * val_ratio)`` go to val, and the rest go to train.

    Returns
    -------
    (train_idx, val_idx, test_idx) : tuple of np.ndarray
        Global sample indices into ``subject_id``, each in ascending order.
    """
    rng = np.random.RandomState(seed)
    subjects, first_seen = np.unique(subject_id, return_index=True)
    order = subjects[np.argsort(first_seen)].copy()
    rng.shuffle(order)

    n_subj = len(order)
    n_test = int(round(n_subj * test_ratio))
    n_val = int(round((n_subj - n_test) * val_ratio))

    test_subjects = set(order[:n_test])
    val_subjects = set(order[n_test:n_test + n_val])
    train_subjects = set(order[n_test + n_val:])

    positions = np.arange(len(subject_id))

    def _indices_of(group):
        # Boolean membership keeps the selected indices in their original ascending order.
        return positions[np.isin(subject_id, list(group))]

    return _indices_of(train_subjects), _indices_of(val_subjects), _indices_of(test_subjects)

def apply_indices(X, y, sep_mask, idx):
    """Index ``X``, ``y`` and ``sep_mask`` with the same ``idx`` and return the three subsets as a tuple."""
    return tuple(arr[idx] for arr in (X, y, sep_mask))

def make_valid_centers(X, y, sep_mask, k=1, pad_mode="edge", drop_mixed=False):
    """Return the center indices whose 2.5D window satisfies the subject-boundary / label-mixing rules.

    The window of center ``i`` is ``[i - k, i + k]``. It is restricted to the run of
    constant ``sep_mask`` that contains ``i`` (runs come from ``_compute_runs``), so a
    window never mixes slices from different subjects.

    Parameters
    ----------
    X : sequence
        Slice stack; only its length is used here.
    y : np.ndarray
        Per-slice labels, same length as ``X``.
    sep_mask : np.ndarray
        Per-slice separator (e.g. subject id); a change in value marks a run boundary.
    k : int
        Half window size (the window spans ``2k + 1`` slices).
    pad_mode : str
        ``"edge"`` clamps out-of-run neighbours to the run's first/last slice; any other
        value is treated as zero padding (out-of-run neighbours are ignored).
    drop_mixed : bool
        If True, skip centers whose in-run window contains a label different from the
        center's label.

    Returns
    -------
    np.ndarray (int64)
        Center indices (positions in the given arrays) in ascending order.

    Raises
    ------
    ValueError
        If ``X``, ``y`` and ``sep_mask`` do not have the same length.
    """
    n = len(X)
    if len(y) != n or len(sep_mask) != n:
        raise ValueError(
            f"X, y and sep_mask must have the same length (got {n}, {len(y)}, {len(sep_mask)})"
        )
    if n == 0:
        # Nothing to center on; also avoids unpacking an empty _compute_runs result.
        return np.empty(0, dtype=np.int64)

    # runs: [(start, end), ...] inclusive; run_id[i]: index of the run containing i
    runs, run_id = _compute_runs(sep_mask)
    offsets = np.arange(-k, k + 1)
    centers = []
    for i in range(n):
        s, e = runs[run_id[i]]
        idxs = i + offsets

        if pad_mode == "edge":
            # Replicate the run's boundary slices instead of crossing into another subject.
            idxs = np.clip(idxs, s, e)
        else:
            # Zero padding: out-of-run neighbours do not exist, so ignore them.
            idxs = idxs[(idxs >= s) & (idxs <= e)]

        if drop_mixed and not np.all(y[idxs] == y[i]):
            continue

        centers.append(i)

    return np.array(centers, dtype=np.int64)

def undersample_on_centers(centers, y, neg_ratio=1.0, seed=42):
    """Undersample negative centers to ``int(neg_ratio * n_pos)`` (reproducible via ``seed``).

    Parameters
    ----------
    centers : array-like of int
        Center indices into ``y``.
    y : np.ndarray
        Labels; a center is positive when ``y[c] == 1`` and negative when ``y[c] == 0``.
        Centers with any other label are not kept in the balanced result.
    neg_ratio : float
        Number of negatives to keep per positive (capped at the number available).
        ``0`` keeps only positives.
    seed : int
        Seed for ``np.random.RandomState`` (negative selection and final shuffle).

    Returns
    -------
    np.ndarray
        All positive centers plus the sampled negatives, shuffled. If there are no
        positive centers, a copy of ``centers`` is returned unchanged.

    Raises
    ------
    ValueError
        If ``neg_ratio`` is negative.
    """
    if neg_ratio < 0:
        raise ValueError(f"neg_ratio must be >= 0, got {neg_ratio}")

    centers = np.asarray(centers)
    rng = np.random.RandomState(seed)
    labels = y[centers]
    pos_c = centers[labels == 1]
    neg_c = centers[labels == 0]

    n_pos = len(pos_c)
    if n_pos == 0:
        # Nothing to balance against; keep everything (copy so callers don't share the array).
        return centers.copy()

    # A ratio that rounds down to zero negatives must still undersample (positives only),
    # rather than falling back to returning every center.
    keep_neg = min(int(neg_ratio * n_pos), len(neg_c))
    if keep_neg > 0:
        sel_neg = rng.choice(neg_c, size=keep_neg, replace=False)
    else:
        sel_neg = neg_c[:0]

    keep = np.concatenate([pos_c, sel_neg])
    rng.shuffle(keep)
    return keep


# -------------------------------
# 헬퍼: 선택한 centers(Dataset에 넘긴 split 배열 기준 인덱스 = split-local 인덱스)를 Dataset 내부 위치로 매핑
# - SliceDataset2p5DMasked는 __init__에서 self.centers(전달받은 배열 기준 인덱스) 배열을 만든다.
# - Subset은 "dataset 내부 위치(0..len(dataset)-1)"를 받으므로, 우리가 고른 centers를
#   dataset.centers 안에서의 위치로 변환해야 한다.
# -------------------------------
def map_centers_to_subset_idx(dataset: "SliceDataset2p5DMasked", selected_centers: np.ndarray):
    """
    Convert selected center indices into positions within ``dataset.centers`` for ``Subset``.

    ``dataset.centers`` holds center indices in the same index space as
    ``selected_centers``; the returned list gives, for each selected center (in the
    order given), its position in ``dataset.centers``.

    Centers that do not appear in ``dataset.centers`` are skipped. Because that means
    the two center computations disagree, a ``UserWarning`` reports how many were
    skipped instead of dropping them silently.

    Returns
    -------
    list of int
        Positions usable as ``torch.utils.data.Subset`` indices.
    """
    import warnings

    # Build a center -> position lookup once for O(1) mapping.
    pos_in_ds = {int(c): i for i, c in enumerate(np.asarray(dataset.centers).tolist())}

    subset_idx = []
    total = 0
    missing = 0
    for c in selected_centers:
        total += 1
        pos = pos_in_ds.get(int(c))
        if pos is None:
            missing += 1
        else:
            subset_idx.append(pos)

    if missing:
        warnings.warn(
            f"map_centers_to_subset_idx: {missing} of {total} selected centers are not in "
            f"dataset.centers and were skipped",
            stacklevel=2,
        )
    return subset_idx



In [13]:
# Load all slices (the pipeline cell below loads them again with the same arguments).
X, y, subject_id = load_all_slices_from_tree(
    root_dir='./01_data/04_flair_preproc_slices/',
    select_N=100,
    choose_major_slice=True,
)

[Loaded] subjects=100, total_slices=317, pos=22, neg=295


In [None]:
import torch

# Hyperparameters
seed = 42
test_ratio = 0.3
val_ratio = 0.1
k = 1                       # 2.5D context size (i±k -> 2k+1 channels in total)
pad_mode = "edge"           # boundary handling: "edge" or "zero"
drop_mixed = False          # if True, use a center only when all labels in its i±k window match it
neg_ratio = 1.0             # undersampling ratio (R negatives per positive)

batch_size = 16
num_workers = 0
device = "cuda" if torch.cuda.is_available() else "cpu"

X, y, subject_id = load_all_slices_from_tree(root_dir='./01_data/04_flair_preproc_slices/',
                                             select_N=100,
                                             choose_major_slice=True)

# 1) Subject-level split (split by subject -> no subject leaks across splits)
train_idx, val_idx, test_idx = split_subjectwise_indices_ordered(
    subject_id, test_ratio=test_ratio, val_ratio=val_ratio, seed=seed
)

# 2) Materialize each split once. Fancy indexing copies the image stack, so
#    re-indexing X[train_idx] for every consumer would copy it repeatedly.
X_tr, y_tr, sid_tr = apply_indices(X, y, subject_id, train_idx)
X_va, y_va, sid_va = apply_indices(X, y, subject_id, val_idx)
X_te, y_te, sid_te = apply_indices(X, y, subject_id, test_idx)

# Window policy shared by center selection and the datasets, so they stay in sync.
center_kw = dict(k=k, pad_mode=pad_mode, drop_mixed=drop_mixed)

# 3) Valid (split-local) center indices per split (subject boundary / label-mixing policy)
tr_centers_all = make_valid_centers(X_tr, y_tr, sid_tr, **center_kw)
va_centers_all = make_valid_centers(X_va, y_va, sid_va, **center_kw)
te_centers_all = make_valid_centers(X_te, y_te, sid_te, **center_kw)

# 4) (Important) Undersample train centers only (seeded for reproducibility)
tr_centers_bal = undersample_on_centers(tr_centers_all, y_tr, neg_ratio=neg_ratio, seed=seed)

# 5) Datasets are built from the split arrays, so their centers are split-local indices
ds_train_full = SliceDataset2p5DMasked(X=X_tr, y=y_tr, sep_mask=sid_tr, **center_kw)
ds_val = SliceDataset2p5DMasked(X=X_va, y=y_va, sep_mask=sid_va, **center_kw)
ds_test = SliceDataset2p5DMasked(X=X_te, y=y_te, sep_mask=sid_te, **center_kw)

# 6) Restrict training to the balanced centers via Subset
#    - ds_train_full.centers: split-local indices of the dataset's valid centers
#    - tr_centers_bal: split-local indices chosen by undersampling
subset_idx = map_centers_to_subset_idx(ds_train_full, tr_centers_bal)
# Every balanced center must exist in the dataset; otherwise Subset silently loses samples.
assert len(subset_idx) == len(tr_centers_bal), (
    f"{len(tr_centers_bal) - len(subset_idx)} balanced centers are missing from ds_train_full.centers"
)
ds_train = Subset(ds_train_full, subset_idx)

# (optional) debug output
print(f"[Centers] train(full)={len(ds_train_full)} | train(bal)={len(ds_train)} "
      f"| val={len(ds_val)} | test={len(ds_test)}")

[Loaded] subjects=100, total_slices=317, pos=22, neg=295
[Centers] train(full)=199 | train(bal)=26 | val=21 | test=97


In [22]:
# Pull a few samples from the balanced train set to check each window's center and label.
from torch.utils.data import DataLoader

dl_debug = DataLoader(ds_train, batch_size=1, shuffle=False)
for n, (xb, yb) in enumerate(dl_debug):
    # xb: (B=1, C=2k+1, H, W), yb: (B,) per the dataset's intended layout.
    # Distinct names keep this loop from overwriting the global label array `y`.
    # ds_train is a Subset: indices[n] is the position in ds_train_full, and
    # ds_train_full.centers maps that position to the split-local center index.
    center_local = int(ds_train_full.centers[ds_train.indices[n]])
    print(f"[sample {n}] center_local={center_local}, label={yb.item()}")
    if n == 5:
        break

[sample 0] center_global=15, label=0.0
[sample 1] center_global=84, label=0.0
[sample 2] center_global=165, label=0.0
[sample 3] center_global=74, label=1.0
[sample 4] center_global=122, label=1.0
[sample 5] center_global=72, label=1.0


In [16]:
# Train-split labels in split-local order (train_idx is ascending).
y[train_idx]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0])

In [17]:
# Split-local positions of the positive slices within the train split.
np.where(y[train_idx]==1)

(array([ 46,  47,  48,  72,  73,  74,  75,  76, 118, 119, 120, 121, 122]),)

In [10]:
# Positions inside ds_train_full selected for the balanced train Subset.
print(subset_idx)

[15, 84, 165, 74, 122, 72, 120, 182, 119, 108, 76, 126, 73, 46, 48, 47, 75, 169, 35, 109, 190, 118, 132, 177, 24, 121]


In [21]:
# subset_idx holds positions inside ds_train_full, not slice indices; map them through
# ds_train_full.centers so the labels shown belong to the slices actually sampled.
print(y[train_idx][np.asarray(ds_train_full.centers)[subset_idx]])

[0 0 0 1 1 1 1 0 1 0 1 0 1 1 1 1 1 0 0 0 0 1 0 0 0 1]


In [18]:
# Label (second element of the sample tuple) of balanced training sample 13.
ds_train[13][1]

tensor(1.)

In [20]:
# All valid split-local train centers, before undersampling.
tr_centers_all

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
       143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
       156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
       169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 18

In [19]:
# Train centers kept after undersampling (shuffled).
tr_centers_bal

array([ 15,  84, 165,  74, 122,  72, 120, 182, 119, 108,  76, 126,  73,
        46,  48,  47,  75, 169,  35, 109, 190, 118, 132, 177,  24, 121])

In [4]:
# 0) Consistency checks on the loaded arrays
assert X.shape[0] == y.shape[0] == subject_id.shape[0]
assert subject_id.ndim == 1 and np.issubdtype(subject_id.dtype, np.integer)

# 1) Inspect subject boundaries (runs) after the split (first 50 entries only)
print("train subject_id head:", subject_id[train_idx][:50])
print("val   subject_id head:", subject_id[val_idx][:50])
print("test  subject_id head:", subject_id[test_idx][:50])

# 2) Center indices must be split-local, i.e. valid positions into the train split.
#    The upper bound is the number of train slices, not len(ds_train_full): the latter
#    counts centers and is smaller whenever drop_mixed removes windows.
tr_ds_centers = np.asarray(ds_train_full.centers)
assert np.all((tr_ds_centers >= 0) & (tr_ds_centers < len(train_idx)))

# 3) Label balance before/after undersampling
y_tr_local = y[train_idx]
pos_all = y_tr_local[tr_centers_all].sum()
neg_all = len(tr_centers_all) - pos_all

pos_bal = y_tr_local[tr_centers_bal].sum()
neg_bal = len(tr_centers_bal) - pos_bal

print(f"[Train centers] total={len(tr_centers_all)} (pos={pos_all}, neg={neg_all})")
print(f"[Train centers (balanced)] total={len(tr_centers_bal)} (pos={pos_bal}, neg={neg_bal})")

print(tr_centers_bal)
print(subset_idx)

print(ds_train.dataset.y)
print(ds_train.dataset.centers)
print(ds_train.indices)

train subject_id head: [ 1  1  1  2  2  2  3  3  3  6  6  6  7  7  7  8  8  8 13 13 13 14 14 14
 16 16 16 17 17 17 19 19 19 20 20 20 21 21 21 23 23 23 24 24 24 25 25 25
 25 25]
val   subject_id head: [ 5  5  5 11 11 11 28 28 28 47 47 47 66 66 66 85 85 85 93 93 93]
test  subject_id head: [ 0  0  0  4  4  4  9  9  9 10 10 10 12 12 12 15 15 15 18 18 18 22 22 22
 26 26 26 30 30 30 31 31 31 33 33 33 39 39 39 39 39 39 39 40 40 40 42 42
 42 44]
[Train centers] total=199 (pos=13, neg=186)
[Train centers (balanced)] total=26 (pos=13, neg=13)
[ 15  84 165  74 122  72 120 182 119 108  76 126  73  46  48  47  75 169
  35 109 190 118 132 177  24 121]
[15, 84, 165, 74, 122, 72, 120, 182, 119, 108, 76, 126, 73, 46, 48, 47, 75, 169, 35, 109, 190, 118, 132, 177, 24, 121]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.
 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 1. 1. 1. 1. 1. 

In [None]:
# Positions inside ds_train_full used by the balanced train Subset.
subset_idx

In [None]:
# Split-local train centers kept after undersampling.
tr_centers_bal

In [None]:
# `tr_centers` is never defined in this notebook; the pipeline cell names the
# unbalanced train centers `tr_centers_all`.
print(train_idx.shape)
print(y[train_idx])
print(tr_centers_all.shape)

print(train_idx)
print(tr_centers_all)
print(tr_centers_bal)
print(y[train_idx][tr_centers_bal])

In [None]:
from modules.data_utils import undersample_negatives

# Xtr / ytr are not assigned anywhere else in this notebook; presumably they mean the
# train split's slices and labels — confirm this is the intended input.
Xtr, ytr = X[train_idx], y[train_idx]
Xtr_bal, ytr_bal = undersample_negatives(Xtr, ytr, neg_ratio=1.0)

In [None]:
# Labels returned by undersample_negatives for the train split.
ytr_bal

In [None]:
# Neither name is assigned anywhere else in this notebook, so derive both here.
_, group_id = _compute_runs(subject_id)   # per-slice run index over the full dataset
sep_va = subject_id[val_idx]              # separator (subject ids) of the val split
print(group_id)
print(sep_va)

In [None]:
# Number of slices in each split (train, val, test).
for split_idx in (train_idx, val_idx, test_idx):
    print(split_idx.shape)

In [None]:
import numpy as np
# Global indices of positive slices (np.argwhere returns one row per match).
np.argwhere(y==1)

In [None]:
# Labels of global slices 75..79 (manual window inspection).
y[[75, 76, 77, 78, 79]]

In [None]:
# `sep_mask` is not defined in this notebook; subject_id is the per-slice separator
# the pipeline passes as sep_mask.
subject_id[[75, 76, 77, 78, 79]]

In [None]:
print(X.shape)
print(y.shape)
# `sep_mask` is not defined in this notebook; subject_id plays that role.
print(subject_id.shape)

In [None]:
# Full (unsplit) dataset for window inspection. `sep_mask` is not defined in this
# notebook; subject_id is what the pipeline passes as sep_mask.
# NOTE(review): this rebinds ds_train, replacing the balanced train Subset built earlier.
ds_train = SliceDataset2p5DMasked(X, y, subject_id, k=1, drop_mixed=False, pad_mode='edge')

In [None]:
# Print row 300, columns 300:310 of channels 0..2 of sample 75 (i-1, i, i+1 when k=1).
xx, yy = ds_train[75]
for ch in range(3):
    print(xx[ch][300, 300:310])

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms

def _compute_runs(sep_mask: np.ndarray):
    """
    sep_mask가 일정하게 유지되는 구간(= 같은 subject 구간)을 (start, end)로 반환.
    sep_mask가 0/1로 번갈아 들어있어도, 값이 바뀌는 지점이 경계가 된다.
    
    - ex) sep_mask = array([0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1])
    - ex) runs = [(0, 2), (3, 5), (6, 8), (9, 11)]
    - ex) run_id = array([ 0,  0,  0,  1]
    """
    
    sep_mask = sep_mask.astype(np.int64)
    if len(sep_mask) == 0:
        return []
    # 경계 지점 찾기: 값이 바뀌는 인덱스
    boundaries = np.where(np.diff(sep_mask) != 0)[0]
    starts = np.r_[0, boundaries + 1]
    ends   = np.r_[boundaries, len(sep_mask) - 1]
    runs = list(zip(starts, ends))  # 각 (start, end) 포함 구간
    # 각 인덱스가 어느 run에 속하는지 맵핑 배열도 같이 반환하면 효율적
    run_id = np.empty(len(sep_mask), dtype=np.int64)
    for rid, (s, e) in enumerate(runs):
        run_id[s:e+1] = rid
    return runs, run_id

# `sep_mask` is not defined in this notebook; subject_id is the per-slice separator the
# pipeline passes as sep_mask (a value change marks a subject boundary).
runs, run_id = _compute_runs(subject_id)

In [None]:
# Inclusive (start, end) bounds of each constant-separator run.
runs

In [None]:
# Run index of every slice (index into `runs`).
run_id

In [None]:
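# [데이터셋 구성]
# 중심 슬라이스 i와 이웃 (i-1, i, i+1) 트리오가 하나의 샘플이 된다.
# i-1 또는 i+1 중 하나라도 (같은 subject 안에) 존재하지 않으면 해당 중심 i는 제외한다.
#  --> 예) 슬라이스가 [1, 2, 3, 4]인 경우, (1, 2, 3)과 (2, 3, 4)는 가능.
#  -->     중심 1은 앞쪽, 중심 4는 뒤쪽 이웃이 없으므로 ((?, 1, 2), (3, 4, ?)) 제외.
#  (참고: pad_mode="edge"는 경계 밖 이웃을 run의 첫/끝 슬라이스로 채우므로 이 규칙과 달리 경계 중심도 유지된다.)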
