In [None]:
import os

import torch
from torch_geometric.loader import NeighborLoader

from utils.utils import train_batchwise, test_batchwise, train_attack_model, evaluate_attack_model, confidence_adaptive_noise
from utils.data import split_multilabel_dataset
from models.GraphSAGE import DeepGraphSAGE

## Graph Load

In [None]:
# Load the preprocessed Google+ graph bundle: a (graph, label binarizer, label list) tuple.
load_path = "./project/dataset/google/gplus_graph.pt"

# weights_only=False is required because the bundle contains non-tensor Python
# objects; this unpickles arbitrary code, so only point it at trusted files.
data, mlb, all_labels = torch.load(load_path, weights_only=False)

print("Loaded data with {} nodes and {} edges".format(data.num_nodes, data.num_edges))

Loaded data with 256780 nodes and 30237805 edges


### 라벨이 없는 노드 및 샘플 수가 300개 이하인 라벨 제거

In [None]:
# 라벨이 존재하는 노드만 마스크 생성
has_label = data.y.sum(dim=1) > 0
labeled_idx = has_label.nonzero(as_tuple=False).view(-1)

# 노드 특성과 라벨 필터링
data.x = data.x[labeled_idx]
data.y = data.y[labeled_idx]

# 라벨의 전체 분포 계산
label_counts = data.y.sum(dim=0)

# 라벨 개수가 100개 이하인 경우 제거
keep_label_indices = (label_counts > 300).nonzero(as_tuple=True)[0]
data.y = data.y[:, keep_label_indices]

print(f"[Filter Labels] Kept {len(keep_label_indices)} labels (with > 300 samples each)")

# edge_index도 labeled_idx만 포함하도록 재매핑
# 먼저 node id → 새 인덱스 매핑 생성
old_to_new = {int(old_idx): new_idx for new_idx, old_idx in enumerate(labeled_idx.tolist())}

# 유효한 edge만 필터링
src, dst = data.edge_index
mask = has_label[src] & has_label[dst]
src = src[mask]
dst = dst[mask]

# 인덱스 재매핑
mapped_src = torch.tensor([old_to_new[int(i)] for i in src.tolist()], dtype=torch.long)
mapped_dst = torch.tensor([old_to_new[int(i)] for i in dst.tolist()], dtype=torch.long)
data.edge_index = torch.stack([mapped_src, mapped_dst], dim=0)

print(f"[Filtered] {data.num_nodes} nodes | {data.num_edges} edges | {data.y.shape[1]} labels")

[Filter Labels] Kept 73 labels (with > 300 samples each)
[Filtered] 39204 nodes | 2147251 edges | 73 labels


## Data Loading

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 데이터셋 분할
target_train_idx, target_test_idx, shadow_train_idx, shadow_test_idx = split_multilabel_dataset(data)

# Target Train
target_train_loader = NeighborLoader(
    data,
    input_nodes=target_train_idx,
    num_neighbors=[15, 10, 5],
    batch_size=128,
    shuffle=True
)

# Target Test
target_test_loader = NeighborLoader(
    data,
    input_nodes=target_test_idx,
    num_neighbors=[15, 10, 5],
    batch_size=128,
    shuffle=True
)

print("Load Target Model Dataset")

# Shadow Train
shadow_train_loader = NeighborLoader(
    data,
    input_nodes=shadow_train_idx,
    num_neighbors=[15, 10, 5],
    batch_size=128,
    shuffle=True
)

# Shadow Test
shadow_test_loader = NeighborLoader(
    data,
    input_nodes=shadow_test_idx,
    num_neighbors=[15, 10, 5],
    batch_size=128,
    shuffle=True
)
print("Load Shadow Model Dataset")

Load Target Model Dataset
Load Shadow Model Dataset


## Target Model

In [None]:
target_model = DeepGraphSAGE(in_channels=data.x.shape[1], hidden_channels=64, out_channels=data.y.shape[1]).to(device)

print("Load Model")

label_count = data.y.sum(dim=0)
pos_weight = (data.y.shape[0] - label_count) / (label_count + 1e-5)

train_batchwise(target_model, target_train_loader, device, 
                epochs=300, learning_rate=0.01, 
                pos_weight=pos_weight, noise=lambda x: confidence_adaptive_noise(x, base_scale=0.2, min_scale=0.05, gamma=1.0))

torch.save(target_model.state_dict(), "./weights/google/300/target/target.pt")

Load Model


  3%|▎         | 10/299 [00:22<11:16,  2.34s/it]

[Epoch 10] Loss: 44.2492


  7%|▋         | 20/299 [00:44<10:22,  2.23s/it]

[Epoch 20] Loss: 44.8719


 10%|█         | 30/299 [01:07<10:06,  2.26s/it]

[Epoch 30] Loss: 44.0705


 13%|█▎        | 40/299 [01:32<11:54,  2.76s/it]

[Epoch 40] Loss: 44.3050


 17%|█▋        | 50/299 [01:58<09:21,  2.25s/it]

[Epoch 50] Loss: 49.6841


 20%|██        | 60/299 [02:26<11:13,  2.82s/it]

[Epoch 60] Loss: 44.6848


 23%|██▎       | 70/299 [02:53<09:28,  2.48s/it]

[Epoch 70] Loss: 44.3723


 27%|██▋       | 80/299 [03:14<07:44,  2.12s/it]

[Epoch 80] Loss: 44.7696


 30%|███       | 90/299 [03:36<07:55,  2.28s/it]

[Epoch 90] Loss: 44.5692


 33%|███▎      | 100/299 [04:00<07:49,  2.36s/it]

[Epoch 100] Loss: 44.6131


 37%|███▋      | 110/299 [04:22<07:10,  2.28s/it]

[Epoch 110] Loss: 44.5805


 40%|████      | 120/299 [04:47<07:10,  2.41s/it]

[Epoch 120] Loss: 44.6423


 43%|████▎     | 130/299 [05:15<08:26,  3.00s/it]

[Epoch 130] Loss: 44.6302


 47%|████▋     | 140/299 [05:43<06:29,  2.45s/it]

[Epoch 140] Loss: 44.5737


 50%|█████     | 150/299 [06:05<05:19,  2.14s/it]

[Epoch 150] Loss: 46.7180


 54%|█████▎    | 160/299 [06:27<05:08,  2.22s/it]

[Epoch 160] Loss: 45.2194


 57%|█████▋    | 170/299 [06:50<04:53,  2.28s/it]

[Epoch 170] Loss: 45.5021


 60%|██████    | 180/299 [07:21<06:32,  3.30s/it]

[Epoch 180] Loss: 45.0762


 64%|██████▎   | 190/299 [07:45<04:18,  2.37s/it]

[Epoch 190] Loss: 44.9772


 67%|██████▋   | 200/299 [08:07<03:38,  2.21s/it]

[Epoch 200] Loss: 45.5066


 70%|███████   | 210/299 [08:29<03:30,  2.37s/it]

[Epoch 210] Loss: 45.2425


 74%|███████▎  | 220/299 [08:53<03:05,  2.35s/it]

[Epoch 220] Loss: 44.9501


 77%|███████▋  | 230/299 [09:15<02:29,  2.17s/it]

[Epoch 230] Loss: 44.9867


 80%|████████  | 240/299 [09:37<02:10,  2.21s/it]

[Epoch 240] Loss: 45.2586


 84%|████████▎ | 250/299 [09:57<01:37,  1.99s/it]

[Epoch 250] Loss: 45.2711


 87%|████████▋ | 260/299 [10:16<01:14,  1.92s/it]

[Epoch 260] Loss: 45.1456


 90%|█████████ | 270/299 [10:36<00:56,  1.94s/it]

[Epoch 270] Loss: 46.3767


 94%|█████████▎| 280/299 [10:55<00:36,  1.92s/it]

[Epoch 280] Loss: 45.4845


 97%|█████████▋| 290/299 [11:14<00:17,  1.92s/it]

[Epoch 290] Loss: 45.8201


100%|██████████| 299/299 [11:32<00:00,  2.31s/it]


In [None]:
# 모델 선언 (입력 차원과 출력 차원 맞추기)
target_model = DeepGraphSAGE(in_channels=data.x.shape[1], hidden_channels=64, out_channels=data.y.shape[1]).to(device)

# weight 파일 로드
target_model.load_state_dict(torch.load("./weights/google/300/target/.pt", weights_only=False))
target_model.eval()

test_batchwise(target_model, target_test_loader, device, threshold=0.97)

[Test] F1-micro: 0.7299 | F1-macro: 0.6995 | Precision-macro: 0.6034 | Recall-macro: 0.8802
F1 per label:
[0.732 0.929 0.744 0.786 0.689 0.903 0.    0.917 0.973 0.913 0.5   0.042
 0.866 0.978 0.649 0.634 0.702 0.802 0.781 0.657 0.644 0.838 0.982 0.672
 0.    0.973 0.456 0.588 0.477 0.859 0.839 0.    0.692 0.681 0.485 1.
 0.74  0.722 0.    0.677 0.657 1.    0.993 0.    0.69  0.987 0.    0.987
 0.448 0.586 0.798 0.979 0.747 0.    0.865 0.838 0.932 0.738 0.851 0.862
 1.    0.794 0.748 0.722 0.838 0.905 0.964 0.817 0.944 0.604 0.724 0.951
 0.571]


(0.7299459130037328,
 0.6994672550637632,
 0.6033945587420071,
 0.8802256417422668,
 array([0.73191489, 0.92857143, 0.74418605, 0.78606965, 0.68852459,
        0.9025641 , 0.        , 0.91712707, 0.97260274, 0.91304348,
        0.5       , 0.04166667, 0.86624204, 0.97752809, 0.64944649,
        0.63436123, 0.70212766, 0.80152672, 0.78095238, 0.65660377,
        0.6440678 , 0.83783784, 0.98214286, 0.67164179, 0.        ,
        0.97333333, 0.4556962 , 0.58823529, 0.47668394, 0.85897436,
        0.83937824, 0.        , 0.69172932, 0.68148148, 0.48453608,
        1.        , 0.74025974, 0.72189349, 0.        , 0.67680608,
        0.65660377, 1.        , 0.99300699, 0.        , 0.68989547,
        0.98684211, 0.        , 0.98684211, 0.44776119, 0.58646617,
        0.7983871 , 0.97916667, 0.74708171, 0.        , 0.86549708,
        0.83832335, 0.93203883, 0.73846154, 0.85067873, 0.86222222,
        1.        , 0.79396985, 0.7483871 , 0.72189349, 0.83832335,
        0.90502793, 0.96402878, 

## Shadow Model

In [16]:
# Shadow model training (no augmentation)
shadow_model = DeepGraphSAGE(in_channels=data.x.shape[1], hidden_channels=64, out_channels=data.y.shape[1]).to(device)

label_count = data.y.sum(dim=0)
pos_weight = (data.y.shape[0] - label_count) / (label_count + 1e-5)

train_batchwise(shadow_model, shadow_train_loader, device, epochs=300, learning_rate=0.01, pos_weight=pos_weight)

torch.save(shadow_model.state_dict(), "./weights/google/300/shadow/shadow.pt")

  3%|▎         | 10/299 [00:13<06:20,  1.32s/it]

[Epoch 10] Loss: 29.7433


  7%|▋         | 20/299 [00:26<06:08,  1.32s/it]

[Epoch 20] Loss: 29.2925


 10%|█         | 30/299 [00:39<05:53,  1.31s/it]

[Epoch 30] Loss: 29.3505


 13%|█▎        | 40/299 [00:52<05:37,  1.31s/it]

[Epoch 40] Loss: 29.3814


 17%|█▋        | 50/299 [01:05<05:31,  1.33s/it]

[Epoch 50] Loss: 29.3942


 20%|██        | 60/299 [01:19<05:15,  1.32s/it]

[Epoch 60] Loss: 29.2177


 23%|██▎       | 70/299 [01:32<05:04,  1.33s/it]

[Epoch 70] Loss: 29.2312


 27%|██▋       | 80/299 [01:45<04:49,  1.32s/it]

[Epoch 80] Loss: 29.4250


 30%|███       | 90/299 [01:58<04:36,  1.32s/it]

[Epoch 90] Loss: 28.9886


 33%|███▎      | 100/299 [02:12<04:29,  1.35s/it]

[Epoch 100] Loss: 29.2570


 37%|███▋      | 110/299 [02:26<04:33,  1.45s/it]

[Epoch 110] Loss: 29.0014


 40%|████      | 120/299 [02:41<04:23,  1.47s/it]

[Epoch 120] Loss: 29.3858


 43%|████▎     | 130/299 [02:56<04:09,  1.47s/it]

[Epoch 130] Loss: 29.2680


 47%|████▋     | 140/299 [03:11<03:58,  1.50s/it]

[Epoch 140] Loss: 29.8416


 50%|█████     | 150/299 [03:27<04:03,  1.63s/it]

[Epoch 150] Loss: 29.8437


 54%|█████▎    | 160/299 [03:42<03:23,  1.46s/it]

[Epoch 160] Loss: 29.5209


 57%|█████▋    | 170/299 [03:57<03:15,  1.51s/it]

[Epoch 170] Loss: 29.6281


 60%|██████    | 180/299 [04:11<02:53,  1.46s/it]

[Epoch 180] Loss: 29.6245


 64%|██████▎   | 190/299 [04:26<02:39,  1.47s/it]

[Epoch 190] Loss: 29.7412


 67%|██████▋   | 200/299 [04:41<02:27,  1.49s/it]

[Epoch 200] Loss: 29.6775


 70%|███████   | 210/299 [04:56<02:10,  1.47s/it]

[Epoch 210] Loss: 30.2550


 74%|███████▎  | 220/299 [05:10<01:54,  1.45s/it]

[Epoch 220] Loss: 29.7785


 77%|███████▋  | 230/299 [05:25<01:42,  1.48s/it]

[Epoch 230] Loss: 30.3682


 80%|████████  | 240/299 [05:39<01:26,  1.46s/it]

[Epoch 240] Loss: 29.9281


 84%|████████▎ | 250/299 [05:54<01:11,  1.45s/it]

[Epoch 250] Loss: 30.0207


 87%|████████▋ | 260/299 [06:08<00:56,  1.44s/it]

[Epoch 260] Loss: 30.2871


 90%|█████████ | 270/299 [06:23<00:42,  1.46s/it]

[Epoch 270] Loss: 30.1545


 94%|█████████▎| 280/299 [06:37<00:27,  1.44s/it]

[Epoch 280] Loss: 29.9040


 97%|█████████▋| 290/299 [06:52<00:12,  1.43s/it]

[Epoch 290] Loss: 30.1070


100%|██████████| 299/299 [07:05<00:00,  1.42s/it]


In [None]:
shadow_model = DeepGraphSAGE(in_channels=data.x.shape[1], hidden_channels=64, out_channels=data.y.shape[1]).to(device)

# weight 파일 로드
shadow_model.load_state_dict(torch.load("./weights/google/300/shadow/shadow.pt", weights_only=False))
shadow_model.eval()

test_batchwise(shadow_model, shadow_test_loader, device, threshold=0.97)

[Test] F1-micro: 0.7277 | F1-macro: 0.7127 | Precision-macro: 0.6482 | Recall-macro: 0.8672
F1 per label:
[0.691 0.943 0.792 0.833 0.645 0.904 0.    0.926 0.986 0.912 0.847 0.066
 0.845 0.938 0.627 0.723 0.711 0.831 0.729 0.628 0.688 0.899 0.952 0.598
 0.279 0.957 0.24  0.582 0.573 0.783 0.855 0.    0.556 0.7   0.43  1.
 0.85  0.679 0.    0.691 0.722 0.973 0.988 0.    0.944 0.961 0.    0.992
 0.554 0.99  0.718 0.952 0.    0.122 0.87  0.85  0.977 0.718 0.785 0.909
 1.    0.79  0.839 0.75  0.86  0.917 0.987 0.86  0.972 0.652 0.829 0.986
 0.674]


(0.7276605298607993,
 0.7127127068302331,
 0.6481626969513311,
 0.8672433192753384,
 array([0.69064748, 0.94308943, 0.79166667, 0.83333333, 0.6446281 ,
        0.9037037 , 0.        , 0.92561983, 0.98630137, 0.91176471,
        0.84684685, 0.06557377, 0.84536082, 0.93793103, 0.62721893,
        0.72340426, 0.71111111, 0.83050847, 0.72857143, 0.62827225,
        0.688     , 0.89908257, 0.95238095, 0.59756098, 0.2791762 ,
        0.95652174, 0.24      , 0.58208955, 0.57317073, 0.7826087 ,
        0.85526316, 0.        , 0.55629139, 0.7       , 0.4295302 ,
        1.        , 0.84955752, 0.67924528, 0.        , 0.69090909,
        0.72222222, 0.97297297, 0.98765432, 0.        , 0.94444444,
        0.96103896, 0.        , 0.99159664, 0.55421687, 0.99047619,
        0.71830986, 0.9516129 , 0.        , 0.12244898, 0.86956522,
        0.84955752, 0.97674419, 0.71830986, 0.78518519, 0.90909091,
        1.        , 0.78980892, 0.83928571, 0.75      , 0.85964912,
        0.91666667, 0.98734177, 

## Attack Model

In [18]:
attack_clf = train_attack_model(shadow_model, data, shadow_train_idx, shadow_test_idx, top_k=30, device=device)
evaluate_attack_model(attack_clf, target_model, data, target_train_idx, target_test_idx, top_k=30, device=device)

[Attack Evaluation] Accuracy: 0.7000 | AUC: 0.5018


(0.6999829946433126, 0.5018297623434288)