In [1]:
import pandas as pd
import numpy as np
import ast

# Load the CSV file into a DataFrame
df = pd.read_csv("CSTNET_df.csv")

# The per-packet sequence columns arrive as strings; parse each cell back
# into a Python literal (list) with ast.literal_eval
for seq_col in ('length_sequence', 'direction_sequence', 'time_sequence'):
    df[seq_col] = df[seq_col].apply(ast.literal_eval)


In [None]:
# Total number of bytes in the session (sum of all packet lengths)
df['total_bytes'] = df['length_sequence'].apply(sum)

# Session duration: last timestamp minus first; 0 when fewer than two packets
df['duration'] = df['time_sequence'].apply(
    lambda stamps: 0 if len(stamps) <= 1 else stamps[-1] - stamps[0]
)

# Forward / backward packet counts (direction code 1 = forward, -1 = backward)
df['fwd_pkt_count'] = df['direction_sequence'].apply(
    lambda dirs: sum(1 for code in dirs if code == 1)
)
df['bwd_pkt_count'] = df['direction_sequence'].apply(
    lambda dirs: sum(1 for code in dirs if code == -1)
)

# Packet-length statistics; an empty sequence yields 0 for every statistic
for stat_col, reducer in (
    ('mean_pkt_len', np.mean),
    ('max_pkt_len', np.max),
    ('min_pkt_len', np.min),
    ('std_pkt_len', np.std),
):
    df[stat_col] = df['length_sequence'].apply(
        lambda lengths, fn=reducer: fn(lengths) if len(lengths) > 0 else 0
    )

df['len_median'] = df['length_sequence'].apply(
    lambda lengths: float(np.median(lengths)) if len(lengths) else 0.0
)
for pct in (25, 75):
    df[f'len_q{pct}'] = df['length_sequence'].apply(
        lambda lengths, q=pct: float(np.percentile(lengths, q)) if len(lengths) else 0.0
    )

# Fraction of packets that are small (<= 64) or near-MTU sized (>= 1400)
df['small_frac'] = df['length_sequence'].apply(
    lambda lengths: float((np.array(lengths) <= 64).mean()) if len(lengths) else 0.0
)
df['mtu_frac'] = df['length_sequence'].apply(
    lambda lengths: float((np.array(lengths) >= 1400).mean()) if len(lengths) else 0.0
)

# Direction / byte based features
seq_pairs = list(zip(df['length_sequence'], df['direction_sequence']))
df['bytes_fwd'] = [
    sum(size for size, code in zip(lengths, dirs) if code == 1)
    for lengths, dirs in seq_pairs
]
df['bytes_bwd'] = [
    sum(size for size, code in zip(lengths, dirs) if code == -1)
    for lengths, dirs in seq_pairs
]

# The small epsilon keeps the ratios finite when the denominator is zero
pkt_denominator = df['num_packets'] + 1e-6
df['fwd_ratio'] = df['fwd_pkt_count'] / pkt_denominator
df['bwd_ratio'] = df['bwd_pkt_count'] / pkt_denominator
df['bytes_ratio'] = df['bytes_fwd'] / (df['total_bytes'] + 1e-6)

# IAT(도착 간격) 상세 통계
def iat_all(time_seq):
    """Summarise the inter-arrival times (IAT) of one timestamp sequence.

    Args:
        time_seq: sequence of packet timestamps for a single session.

    Returns:
        An 8-tuple of floats, in this order: mean, standard deviation,
        median, minimum, maximum, 25th percentile, 75th percentile and
        coefficient of variation (std / (mean + 1e-9)) of the consecutive
        timestamp differences. A sequence with fewer than two timestamps
        has no differences, so all eight values are 0.0.
    """
    if len(time_seq) < 2:
        return (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

    gaps = np.diff(time_seq)
    avg_gap = float(np.mean(gaps))
    spread = np.std(gaps)
    # The epsilon in the denominator avoids a division by zero when the mean gap is 0
    variation = float(spread / (avg_gap + 1e-9))

    return (
        avg_gap,
        float(spread),
        float(np.median(gaps)),
        float(np.min(gaps)),
        float(np.max(gaps)),
        float(np.percentile(gaps, 25)),
        float(np.percentile(gaps, 75)),
        variation,
    )

# One 8-tuple of IAT statistics per session, transposed into eight columns
per_session_iat = df['time_sequence'].apply(iat_all)
(
    df['mean_iat'],
    df['iat_std'],
    df['iat_median'],
    df['iat_min'],
    df['iat_max'],
    df['iat_q25'],
    df['iat_q75'],
    df['iat_cv'],
) = zip(*per_session_iat)

# Rate / intensity features; durations are floored at 1e-6 so the division is safe
dur = np.maximum(df['duration'].to_numpy(dtype=float), 1e-6)
df['pps'] = df['num_packets'] / dur
df['bps'] = df['total_bytes'] / dur
for c in ['total_bytes', 'duration', 'pps', 'bps']:
    df[f'{c}_log'] = np.log1p(df[c])

# Early pattern: length and direction of the first K packets (0 when the session is shorter)
K = 5
for k in range(K):
    df[f'len_{k}'] = df['length_sequence'].apply(
        lambda seq, pos=k: int(seq[pos]) if pos < len(seq) else 0
    )
    df[f'dir_{k}'] = df['direction_sequence'].apply(
        lambda seq, pos=k: int(seq[pos]) if pos < len(seq) else 0
    )

# Port-based hints (destination port 443 / 80, plus log-scaled port number)
df['is_tls'] = df['dst_port'].eq(443).astype(int)
df['is_http'] = df['dst_port'].eq(80).astype(int)
df['dst_port_log'] = np.log1p(df['dst_port'])

# Summary: frame shape and the number of columns that are not raw inputs
print(f"데이터 shape: {df.shape}")
non_feature_cols = {
    'src_ip', 'src_port', 'dst_ip', 'dst_port', 'protocol', 'num_packets',
    'time_sequence', 'length_sequence', 'direction_sequence', 'from', 'app_name',
}
n_generated = sum(1 for col in df.columns if col not in non_feature_cols)
print(f"총 {n_generated}개 피처 생성됨")

데이터 shape: (36941, 56)
총 45개 피처 생성됨


In [None]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Columns used as model input (X); the label (y) is taken separately.
# The order of this list is the column order of the training matrix.
feature_columns = [
    # basic counts
    *('num_packets', 'total_bytes', 'duration'),
    # log-scaled
    *('total_bytes_log', 'duration_log'),
    # rates
    *('pps', 'bps', 'pps_log', 'bps_log'),
    # packet-length statistics
    *('mean_pkt_len', 'std_pkt_len', 'min_pkt_len', 'max_pkt_len'),
    *('len_median', 'len_q25', 'len_q75', 'small_frac', 'mtu_frac'),
    # direction / bytes
    *('fwd_pkt_count', 'bwd_pkt_count', 'fwd_ratio', 'bwd_ratio'),
    *('bytes_fwd', 'bytes_bwd', 'bytes_ratio'),
    # inter-arrival times
    *('mean_iat', 'iat_std', 'iat_median', 'iat_min',
      'iat_max', 'iat_q25', 'iat_q75', 'iat_cv'),
    # early pattern: first 5 packet lengths, then first 5 directions
    *(f'len_{pos}' for pos in range(5)),
    *(f'dir_{pos}' for pos in range(5)),
    # port hints
    *('is_tls', 'is_http', 'dst_port_log'),
]

from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report, top_k_accuracy_score
import time

# Group-aware split: rows sharing a `from` value (per the original note, the
# source file) are kept entirely in train or entirely in test to avoid leakage
groups = df['from']
labels = df['app_name']

gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(gss.split(df, labels, groups))

X_train = df.iloc[train_idx][feature_columns]
X_test  = df.iloc[test_idx][feature_columns]
y_train_raw = labels.iloc[train_idx]
y_test_raw  = labels.iloc[test_idx]

# The encoder is fitted on the training labels only; transform() raises
# ValueError if the test split contains a class never seen in training
enc = LabelEncoder()
y_train = enc.fit_transform(y_train_raw)
y_test  = enc.transform(y_test_raw)

# RandomForest (baseline hyper-parameters before tuning)
rf = RandomForestClassifier(
    n_estimators=500,
    max_depth=None,
    min_samples_leaf=2,
    max_features='sqrt',
    class_weight='balanced_subsample',
    n_jobs=-1,
    random_state=42,
    verbose=1
)

print("학습 시작...")
# perf_counter is monotonic, so the elapsed time cannot be skewed by clock adjustments
t0 = time.perf_counter()
rf.fit(X_train, y_train)
print(f"학습 완료 ({time.perf_counter()-t0:.2f}s)")

# Evaluation.
# Run the forest over the test set once: the predicted class of a random
# forest is the argmax of its averaged class probabilities, so y_pred is
# derived from the same probability matrix that top-k accuracy needs.
proba = rf.predict_proba(X_test)
y_pred = rf.classes_[np.argmax(proba, axis=1)]

acc = accuracy_score(y_test, y_pred)
macro_f1 = f1_score(y_test, y_pred, average='macro')
# `labels` tells the metric which class each probability column belongs to;
# without it the call fails when the test split lacks one of the trained classes
top5 = top_k_accuracy_score(y_test, proba, k=5, labels=rf.classes_)

print(f"Accuracy: {acc*100:.2f}%")
print(f"Macro-F1: {macro_f1*100:.2f}%")
print(f"Top-5 Accuracy: {top5*100:.2f}%")
print("\nClassification report:")
# Explicit `labels` keep the rows aligned with `target_names` even when a class
# is absent from both y_test and y_pred; such rows report 0 instead of warning
print(classification_report(
    y_test,
    y_pred,
    labels=np.arange(len(enc.classes_)),
    target_names=enc.classes_,
    zero_division=0,
))

# Top-20 feature importances
fi = pd.DataFrame({'feature': feature_columns, 'importance': rf.feature_importances_}) \
       .sort_values('importance', ascending=False)
print("\nTop-20 Feature Importances:")
print(fi.head(20).to_string(index=False))

학습 시작...


[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 16 concurrent workers.
[Parallel(n_jobs=-1)]: Done  18 tasks      | elapsed:    0.4s
[Parallel(n_jobs=-1)]: Done 168 tasks      | elapsed:    2.5s
[Parallel(n_jobs=-1)]: Done 418 tasks      | elapsed:    6.1s
[Parallel(n_jobs=-1)]: Done 600 out of 600 | elapsed:    8.6s finished
[Parallel(n_jobs=16)]: Using backend ThreadingBackend with 16 concurrent workers.
[Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done 168 tasks      | elapsed:    0.2s


학습 완료 (8.69s)


[Parallel(n_jobs=16)]: Done 418 tasks      | elapsed:    0.4s
[Parallel(n_jobs=16)]: Done 600 out of 600 | elapsed:    0.6s finished
[Parallel(n_jobs=16)]: Using backend ThreadingBackend with 16 concurrent workers.
[Parallel(n_jobs=16)]: Done  18 tasks      | elapsed:    0.0s
[Parallel(n_jobs=16)]: Done 168 tasks      | elapsed:    0.2s
[Parallel(n_jobs=16)]: Done 418 tasks      | elapsed:    0.4s
[Parallel(n_jobs=16)]: Done 600 out of 600 | elapsed:    0.6s finished


Accuracy: 81.13%
Macro-F1: 80.95%
Top-5 Accuracy: 93.72%

Classification report:
                        precision    recall  f1-score   support

               163.com       0.69      0.63      0.66        92
             51cto.com       0.84      0.86      0.85        88
            alicdn.com       0.75      0.70      0.73       110
            alipay.com       0.88      0.84      0.86       100
              amap.com       0.98      0.99      0.99       103
         amazonaws.com       0.88      0.88      0.88        78
        ampproject.org       0.81      0.81      0.81       107
             apple.com       0.80      0.51      0.62       100
             arxiv.org       0.80      0.86      0.83        90
              asus.com       0.91      0.95      0.93        73
         azureedge.net       0.88      0.96      0.92        97
             baidu.com       0.71      0.81      0.75        99
          bilibili.com       0.89      0.87      0.88       107
          biligame.com