<a href="https://colab.research.google.com/github/funway/nid-imbalance-study/blob/main/preprocessing/preprocess.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to the CSE-CIC-IDS2018 Dataset

🚀 NYIT 880 | 🧑🏻‍💻 funway

## 1. Downloading the dataset
Use the AWS CLI to download the CSV files from cloud storage:
```
aws s3 sync --no-sign-request --region us-east-2 "s3://cse-cic-ids2018/Processed Traffic Data for ML Algorithms/" ./
```


## 2. Dataset contents
Each CSV file normally contains 80 columns:
- 79 feature columns (`features`)
- 1 label column (`label`)

### 2.1 Feature columns
```
['ACK Flag Cnt', 'Active Max', 'Active Mean', 'Active Min', 'Active Std', 'Bwd Blk Rate Avg', 'Bwd Byts/b Avg', 'Bwd Header Len', 'Bwd IAT Max', 'Bwd IAT Mean', 'Bwd IAT Min', 'Bwd IAT Std', 'Bwd IAT Tot', 'Bwd PSH Flags', 'Bwd Pkt Len Max', 'Bwd Pkt Len Mean', 'Bwd Pkt Len Min', 'Bwd Pkt Len Std', 'Bwd Pkts/b Avg', 'Bwd Pkts/s', 'Bwd Seg Size Avg', 'Bwd URG Flags', 'CWE Flag Count', 'Down/Up Ratio', 'Dst Port', 'ECE Flag Cnt', 'FIN Flag Cnt', 'Flow Byts/s', 'Flow Duration', 'Flow IAT Max', 'Flow IAT Mean', 'Flow IAT Min', 'Flow IAT Std', 'Flow Pkts/s', 'Fwd Act Data Pkts', 'Fwd Blk Rate Avg', 'Fwd Byts/b Avg', 'Fwd Header Len', 'Fwd IAT Max', 'Fwd IAT Mean', 'Fwd IAT Min', 'Fwd IAT Std', 'Fwd IAT Tot', 'Fwd PSH Flags', 'Fwd Pkt Len Max', 'Fwd Pkt Len Mean', 'Fwd Pkt Len Min', 'Fwd Pkt Len Std', 'Fwd Pkts/b Avg', 'Fwd Pkts/s', 'Fwd Seg Size Avg', 'Fwd Seg Size Min', 'Fwd URG Flags', 'Idle Max', 'Idle Mean', 'Idle Min', 'Idle Std', 'Init Bwd Win Byts', 'Init Fwd Win Byts', 'Label', 'PSH Flag Cnt', 'Pkt Len Max', 'Pkt Len Mean', 'Pkt Len Min', 'Pkt Len Std', 'Pkt Len Var', 'Pkt Size Avg', 'Protocol', 'RST Flag Cnt', 'SYN Flag Cnt', 'Subflow Bwd Byts', 'Subflow Bwd Pkts', 'Subflow Fwd Byts', 'Subflow Fwd Pkts', 'Timestamp', 'Tot Bwd Pkts', 'Tot Fwd Pkts', 'TotLen Bwd Pkts', 'TotLen Fwd Pkts', 'URG Flag Cnt']
```

**Exception**:
- Only Thuesday-20-02-2018_TrafficForML_CICFlowMeter.csv has four extra columns, `['Dst IP', 'Src Port', 'Flow ID', 'Src IP']`, which need to be dropped.
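One way to spot a file that carries the extra columns is to read only its header row. The snippet below is a minimal sketch, not part of the original notebook: `find_extra_columns`, `EXTRA_COLS`, and the two in-memory CSV strings are hypothetical stand-ins for the real files.

```python
import io
import pandas as pd

# The four extra columns named in the note above.
EXTRA_COLS = ['Flow ID', 'Src IP', 'Src Port', 'Dst IP']

def find_extra_columns(csv_source, extra_cols=EXTRA_COLS):
    """Return the extra columns present in a CSV, reading only its header row."""
    header = pd.read_csv(csv_source, nrows=0)  # nrows=0 parses the header only
    return [col for col in extra_cols if col in header.columns]

# Toy stand-ins for a regular file and a file with extra columns (made-up rows).
normal_csv = io.StringIO("Dst Port,Protocol,Label\n80,6,Benign\n")
wide_csv = io.StringIO(
    "Flow ID,Src IP,Src Port,Dst IP,Dst Port,Label\n"
    "x,1.1.1.1,5000,2.2.2.2,80,Benign\n"
)

normal_extra = find_extra_columns(normal_csv)
wide_extra = find_extra_columns(wide_csv)
print(normal_extra)
print(wide_extra)
```

On the real dataset the same function could be called with each CSV path in turn; because only the header is parsed, the check stays cheap even for multi-gigabyte files.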

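### 2.2 String-typed columns
Of the columns kept after preprocessing, only `Label` holds strings; each value names the traffic class of that row (`Benign` or a specific attack type).
All other columns are numeric.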

### 2.3 标签值
 The `Label` column takes 15 distinct values:
 ```
 ['Benign', 'Bot', 'Brute Force -Web', 'Brute Force -XSS', 'DDOS attack-HOIC',
 'DDOS attack-LOIC-UDP', 'DDoS attacks-LOIC-HTTP', 'DoS attacks-GoldenEye',
 'DoS attacks-Hulk', 'DoS attacks-SlowHTTPTest', 'DoS attacks-Slowloris',
 'FTP-BruteForce', 'Infilteration', 'SQL Injection', 'SSH-Bruteforce']
 ```

# Preprocessing the CSE-CIC-IDS2018 Dataset



In [None]:
### Mount Google Drive ###
import os
from google.colab import drive

if not os.path.exists('/content/drive/MyDrive'):
    drive.mount('/content/drive')

# List the project directory
!ls -thl /content/drive/MyDrive/NYIT/880/

total 16K
drwx------ 2 root root 4.0K Jun  4 16:53 data
drwx------ 2 root root 4.0K Jun  4 16:53 data_bak
drwx------ 2 root root 4.0K May 28 05:33 code
drwx------ 2 root root 4.0K May 28 05:31 dataset


In [None]:
### Modules ###
from pathlib import Path
from datetime import datetime
from collections import Counter
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler


### Globals ###
## Data directories
dataset = 'cse-cic-ids2018'
project_folder = Path('/content/drive/MyDrive/NYIT/880')
dataset_folder = project_folder / 'dataset' / dataset
preprocessed_folder = project_folder / 'data/preprocessed'
balanced_folder = project_folder / 'data/balanced'
splits_folder = preprocessed_folder / 'splits'


## Collect the CSV files (csv_reg is a glob pattern, not a regex)
csv_reg = '*.csv'
csv_files = sorted(dataset_folder.rglob(csv_reg))  # sorted so the part numbering is reproducible
for csv in csv_files:
    print(f'csv file: {csv}, {type(csv)}')

## Columns that are not used as features
cols_to_drop = ['Flow ID', 'Src IP', 'Dst IP', 'Src Port', 'Timestamp']

## All possible values of the Label column (sorted)
unique_labels = ['Benign', 'Bot', 'Brute Force -Web', 'Brute Force -XSS', 'DDOS attack-HOIC', 'DDOS attack-LOIC-UDP', 'DDoS attacks-LOIC-HTTP', 'DoS attacks-GoldenEye', 'DoS attacks-Hulk', 'DoS attacks-SlowHTTPTest', 'DoS attacks-Slowloris', 'FTP-BruteForce', 'Infilteration', 'SQL Injection', 'SSH-Bruteforce']
label_mapping = {label: idx for idx, label in enumerate(unique_labels)}
print(f"[{datetime.now().strftime('%x %X')}] 🏷️ Label mapping: {label_mapping}")


### Utilities ###
def get_label_counts(y: np.ndarray) -> dict[int, int]:
    """Return {label: count} for an integer label array, sorted by label."""
    return {int(k): v for k, v in sorted(Counter(y).items())}

csv file: /content/drive/MyDrive/NYIT/880/dataset/cse-cic-ids2018/Friday-02-03-2018_TrafficForML_CICFlowMeter.csv, <class 'pathlib.PosixPath'>
csv file: /content/drive/MyDrive/NYIT/880/dataset/cse-cic-ids2018/Friday-16-02-2018_TrafficForML_CICFlowMeter.csv, <class 'pathlib.PosixPath'>
csv file: /content/drive/MyDrive/NYIT/880/dataset/cse-cic-ids2018/Friday-23-02-2018_TrafficForML_CICFlowMeter.csv, <class 'pathlib.PosixPath'>
csv file: /content/drive/MyDrive/NYIT/880/dataset/cse-cic-ids2018/Thursday-01-03-2018_TrafficForML_CICFlowMeter.csv, <class 'pathlib.PosixPath'>
csv file: /content/drive/MyDrive/NYIT/880/dataset/cse-cic-ids2018/Thursday-15-02-2018_TrafficForML_CICFlowMeter.csv, <class 'pathlib.PosixPath'>
csv file: /content/drive/MyDrive/NYIT/880/dataset/cse-cic-ids2018/Thursday-22-02-2018_TrafficForML_CICFlowMeter.csv, <class 'pathlib.PosixPath'>
csv file: /content/drive/MyDrive/NYIT/880/dataset/cse-cic-ids2018/Wednesday-14-02-2018_TrafficForML_CICFlowMeter.csv, <class 'pathlib.Po

## Get all unique labels

In [None]:
## Collect all possible values of Label ##

# unique_labels = []

if not unique_labels:
    all_labels_count = {}  # label -> count across all CSV files

    # Iterate over the CSV files
    for csv in csv_files:
        print(f'Reading csv file: {Path(csv).name}')

        csv_labels_count = {}  # label -> count within the current CSV file

        # Read in chunks to limit memory use
        chunk_size = 100000
        for chunk in pd.read_csv(csv, usecols=['Label'], chunksize=chunk_size):
            # Count the labels in this chunk, e.g. {'Benign': 100, 'Bot': 99}
            chunk_labels_count = chunk['Label'].value_counts().to_dict()

            for label, count in chunk_labels_count.items():
                # Update the counts for the current CSV file
                csv_labels_count[label] = csv_labels_count.get(label, 0) + count

                # Update the counts across all CSV files
                all_labels_count[label] = all_labels_count.get(label, 0) + count

        # Print the unique labels of the current CSV file
        print(f'  unique labels: [{len(csv_labels_count)}], {dict(sorted(csv_labels_count.items()))}\n')

    # Print all unique labels
    print(f'All unique labels count: [{len(all_labels_count)}] \n{dict(sorted(all_labels_count.items()))}\n')

    # Remove the pseudo-label 'Label' (it comes from repeated header rows), if present
    all_labels_count.pop('Label', None)

    # Convert to a sorted list
    unique_labels = sorted(all_labels_count.keys())
    label_mapping = {label: idx for idx, label in enumerate(unique_labels)}

    print(f"[{datetime.now().strftime('%x %X')}] All unique labels: [{len(label_mapping)}] (removed 'Label')\n{label_mapping}")
else:
    print(f"[{datetime.now().strftime('%x %X')}] unique_labels has been set: [{len(label_mapping)}]\n{label_mapping}")

[06/04/25 17:31:08] unique_labels has been set: [15]
{'Benign': 0, 'Bot': 1, 'Brute Force -Web': 2, 'Brute Force -XSS': 3, 'DDOS attack-HOIC': 4, 'DDOS attack-LOIC-UDP': 5, 'DDoS attacks-LOIC-HTTP': 6, 'DoS attacks-GoldenEye': 7, 'DoS attacks-Hulk': 8, 'DoS attacks-SlowHTTPTest': 9, 'DoS attacks-Slowloris': 10, 'FTP-BruteForce': 11, 'Infilteration': 12, 'SQL Injection': 13, 'SSH-Bruteforce': 14}
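Because `unique_labels` was already set, the scanning branch did not run in the recorded output above. The following is a small sketch of the same chunked-counting idea on an in-memory CSV, so it can be tried without the dataset. The CSV content, the chunk size of 2, and the variable names are illustrative assumptions.

```python
import io
from collections import Counter
import pandas as pd

# Toy CSV (made-up rows); the fourth data row repeats the header, as some real files do.
toy_csv = io.StringIO("Label\nBenign\nBot\nBenign\nLabel\nBenign\n")

label_counts = Counter()
# Read two rows at a time so the whole file never has to fit in memory.
for chunk in pd.read_csv(toy_csv, usecols=['Label'], chunksize=2):
    label_counts.update(chunk['Label'].value_counts().to_dict())

label_counts.pop('Label', None)  # discard the repeated header row

# Sorted labels -> consecutive integer ids
toy_mapping = {label: idx for idx, label in enumerate(sorted(label_counts))}
print(dict(label_counts), toy_mapping)
```

`Counter.update` adds counts rather than replacing them, which is why it can stand in for the explicit `dict.get(label, 0) + count` bookkeeping.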


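## Data Cleaning
1. Drop unused feature columns
2. Drop malformed rows (rows that repeat the column names)
3. Convert numeric feature columns to float32 (unparseable values become NaN)
4. Handle Inf values (drop the rows)
5. Handle NaN values (drop the rows)
6. Encode the label column as integers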
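The six steps can be tried on a tiny frame before running them on the full dataset. This is a minimal sketch under stated assumptions: the rows, the two-label `toy_mapping`, and the choice of columns are made up for illustration and are not taken from the dataset.

```python
import numpy as np
import pandas as pd

# Toy frame with one repeated header row, one Inf value and one missing value.
raw = pd.DataFrame({
    'Flow ID':     ['a', 'b', 'c', 'd', 'e'],
    'Flow Byts/s': ['10.5', 'Flow Byts/s', np.inf, '3.0', None],
    'Dst Port':    ['80', 'Dst Port', '443', '22', '53'],
    'Label':       ['Benign', 'Label', 'Bot', 'Bot', 'Benign'],
})
toy_mapping = {'Benign': 0, 'Bot': 1}  # hypothetical two-class mapping

# 1. Drop unused columns
df = raw.drop(columns=['Flow ID'])
# 2. Drop rows that repeat the column names
df = df[df['Label'] != 'Label'].copy()
# 3. Convert feature columns to float32; unparseable values become NaN
features = df.columns.difference(['Label'])
df[features] = df[features].apply(pd.to_numeric, errors='coerce').astype('float32')
# 4. Drop rows containing +Inf or -Inf
df = df[~df[features].isin([np.inf, -np.inf]).any(axis=1)]
# 5. Drop rows containing NaN
df = df.dropna().reset_index(drop=True)
# 6. Encode the label column as integers
df['Label'] = df['Label'].map(toy_mapping).astype('int32')

print(df)
```

Two of the five rows survive: the repeated header row, the Inf row, and the row with the missing value are removed in steps 2, 4, and 5 respectively.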



In [None]:
cleaned_folder = preprocessed_folder / 'cleaned.parquet'
reclean = False

if not cleaned_folder.exists() or reclean:
    # Make sure the output directory exists
    cleaned_folder.mkdir(parents=True, exist_ok=True)

    for i, csv in enumerate(csv_files):
        # Load the file
        print(f"[{datetime.now().strftime('%x %X')}] Loading csv file [{i+1}]: {csv.name}")
        df = pd.read_csv(csv, nrows=None, low_memory=False)
        print(f'  Contains [{len(df.columns)}] columns: {sorted(df.columns.tolist())}')
        print(f"  Label values: {df['Label'].value_counts().to_dict()}")


        ## Drop the unused feature columns ##
        cols_to_drop_exist = [col for col in cols_to_drop if col in df.columns]
        df = df.drop(cols_to_drop_exist, axis=1)  # axis=1 drops columns
        print(f'  ❎ Dropped unused feature columns: {cols_to_drop_exist}')
        print(f'    [{len(df.columns)}] columns remain')


        ## Drop rows whose Label value equals 'Label' ##
        # Some files contain rows that repeat the column names
        if 'Label' in df.columns and ('Label' in df['Label'].values):
            df = df[df['Label'] != 'Label'].copy()
            print('  ❎ Dropped rows whose Label value equals "Label"')
            print(f"    Label values after dropping: {df['Label'].value_counts().to_dict()}")


        ## Convert the feature columns from object to a numeric dtype ##
        print('  🔁 Converting numeric features ⇨ float32')
        # Select the feature columns, excluding 'Label'
        features = df.columns.difference(['Label'])
        # Convert to numeric (values that cannot be parsed are coerced to NaN)
        df[features] = df[features].apply(pd.to_numeric, errors='coerce').astype('float32')


        ## Handle Inf values ##
        print(f'  ⚠️ Handling Inf values. Current shape={df.shape}')
        print(f'    +Inf count: {(df == np.inf).sum().sum()}, -Inf count: {(df == -np.inf).sum().sum()}')
        # Option 1: drop the row
        df = df[~df.isin([np.inf, -np.inf]).any(axis=1)].reset_index(drop=True)

        # Option 2: replace with the column's max/min value
        # max_value = df.replace([np.inf, -np.inf], np.nan).max()
        # min_value = df.replace([np.inf, -np.inf], np.nan).min()
        # df.replace(np.inf, max_value, inplace=True)
        # df.replace(-np.inf, min_value, inplace=True)
        # Option 3: replace with NaN
        # df = df.replace([np.inf, -np.inf], np.nan)

        print(f'    shape after handling={df.shape}')


        ## Handle NaN values ##
        print(f'  ⚠️ Handling NaN values. NaN count: {df.isna().sum().sum()}')
        # Option 1: drop the row
        df = df.dropna().reset_index(drop=True)
        # Option 2: fill in values
        # df['Label'] = df['Label'].fillna('Benign')  # fill the Label column
        # df = df.fillna(0)  # fill the other columns
        print(f'    shape after handling={df.shape}')


        ## Encode the Label column as integers ##
        print("  🔁 Encoding the label column as integers...")
        def encode_label(label):
            if label in label_mapping:
                return label_mapping[label]
            else:
                raise ValueError(f"Unknown label '{label}' encountered during encoding.")

        print(f"    Label values before encoding: {df['Label'].value_counts().to_dict()}")
        df['Label'] = df['Label'].apply(encode_label).astype('int32')
        print(f"    Label values after encoding: {df['Label'].value_counts().to_dict()}")


        ## Save the file ##
        # df.to_csv('combined.csv', mode='a', header=(i==0), index=False)
        output_file = cleaned_folder / f'part-{i+1:03d}.parquet'
        df.to_parquet(output_file, index=False)
        print(f"[{datetime.now().strftime('%x %X')}] 💾 Saved file [{i+1}]: {csv.name} >> {output_file}")

        print("-------------------------")
else:
    print(f"[{datetime.now().strftime('%x %X')}] ⏭️ {cleaned_folder} already exists, skipping this step")

[06/04/25 17:31:15] Loading csv file [1]: Friday-02-03-2018_TrafficForML_CICFlowMeter.csv
  Contains [80] columns: ['ACK Flag Cnt', 'Active Max', 'Active Mean', 'Active Min', 'Active Std', 'Bwd Blk Rate Avg', 'Bwd Byts/b Avg', 'Bwd Header Len', 'Bwd IAT Max', 'Bwd IAT Mean', 'Bwd IAT Min', 'Bwd IAT Std', 'Bwd IAT Tot', 'Bwd PSH Flags', 'Bwd Pkt Len Max', 'Bwd Pkt Len Mean', 'Bwd Pkt Len Min', 'Bwd Pkt Len Std', 'Bwd Pkts/b Avg', 'Bwd Pkts/s', 'Bwd Seg Size Avg', 'Bwd URG Flags', 'CWE Flag Count', 'Down/Up Ratio', 'Dst Port', 'ECE Flag Cnt', 'FIN Flag Cnt', 'Flow Byts/s', 'Flow Duration', 'Flow IAT Max', 'Flow IAT Mean', 'Flow IAT Min', 'Flow IAT Std', 'Flow Pkts/s', 'Fwd Act Data Pkts', 'Fwd Blk Rate Avg', 'Fwd Byts/b Avg', 'Fwd Header Len', 'Fwd IAT Max', 'Fwd IAT Mean', 'Fwd IAT Min', 'Fwd IAT Std', 'Fwd IAT Tot', 'Fwd PSH Flags', 'Fwd Pkt Len Max', 'Fwd Pkt Len Mean', 'Fwd Pkt Len Min', 'Fwd Pkt Len Std', 'Fwd Pkts/b Avg', 'Fwd Pkts/s', 'Fwd Seg Size Avg', 'Fwd Seg Size Min', 'Fwd URG Flags', '

## Data Trimming
- Drop feature columns that are all zeros
- Drop rows containing any value less than -1
- Downsample over-represented labels (mainly `Benign`, capped at 2,000,000 rows)
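The capping step can be illustrated on a toy frame. The sketch below uses a per-label `groupby` instead of a marker column, so it is an alternative formulation of the same idea rather than the notebook's exact code. `cap_class_counts`, its parameters, and the toy data are hypothetical.

```python
import pandas as pd

def cap_class_counts(df, caps, label_col='Label', seed=42):
    """Downsample each label listed in `caps` to at most caps[label] rows.

    Labels that are not in `caps`, or that are already under their cap,
    are kept in full. The original row order is preserved.
    """
    parts = []
    for label, group in df.groupby(label_col, sort=True):
        limit = caps.get(label)
        if limit is not None and len(group) > limit:
            group = group.sample(n=limit, random_state=seed)
        parts.append(group)
    return pd.concat(parts).sort_index().reset_index(drop=True)

# Toy data: 7 rows of label 0, 2 rows of label 1, 1 row of label 2.
toy = pd.DataFrame({'x': range(10), 'Label': [0] * 7 + [1] * 2 + [2]})

# Cap label 0 at 3 rows; the cap for label 1 exceeds its size, so it is left alone.
trimmed = cap_class_counts(toy, caps={0: 3, 1: 5})
print(trimmed['Label'].value_counts().sort_index().to_dict())
```

Fixing `random_state` makes the sampled subset reproducible across runs, which matters when later steps (splitting, balancing) are compared against each other.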

In [None]:
trimed_file = preprocessed_folder / 'trimed_data.parquet'
re_trim = False

# trim_to maps a label to the maximum number of rows kept for it ⇨ count(label) <= trim_to[label]
trim_to = {0:2000000, 1:300000, 3:300000} # labels 1 and 3 have fewer than 300,000 rows, so they are not actually trimmed

if not trimed_file.exists() or re_trim:
    ## Load all cleaned data at once ##
    print(f"[{datetime.now().strftime('%x %X')}] 🚀 Loading cleaned data from {cleaned_folder}")

    df = pd.read_parquet(cleaned_folder)
    print(f"[{datetime.now().strftime('%x %X')}] shape={df.shape}")
    print(f"[{datetime.now().strftime('%x %X')}] Label values: {dict(sorted(df['Label'].value_counts().items()))}")


    ## Drop any columns that are entirely zero ##
    zero_columns = df.columns[(df==0).all()]
    print(f"[{datetime.now().strftime('%x %X')}] 🧹 Found {len(zero_columns)} all-zero columns at positions: {[df.columns.get_loc(col) for col in zero_columns]}")

    df = df.drop(zero_columns, axis=1)
    print(f"[{datetime.now().strftime('%x %X')}] 🧹 After dropping: shape={df.shape}")

    ## Handle values below -1 ##
    rows_with_negatives = df[df.lt(-1).any(axis=1)]
    print(f"[{datetime.now().strftime('%x %X')}] ⚠️ Handling values below -1: found {len(rows_with_negatives)} rows containing a value below -1")
    df.drop(index=rows_with_negatives.index, inplace=True)
    print(f"[{datetime.now().strftime('%x %X')}] ⚠️ After handling: shape={df.shape}")
    print(f"[{datetime.now().strftime('%x %X')}] ⚠️ Label values: {dict(sorted(df['Label'].value_counts().items()))}")

    ## Trim ##
    # Mark which rows to keep
    df['__keep__'] = False
    for label, max_count in trim_to.items():
        mask = (df['Label'] == label)
        if mask.sum() <= max_count:
            keep_indices = df[mask].index
        else:
            keep_indices = df[mask].sample(n=max_count, random_state=42).index
        df.loc[keep_indices, '__keep__'] = True

    # Labels that do not appear in trim_to are kept in full
    df.loc[~df['Label'].isin(trim_to.keys()), '__keep__'] = True

    # Filter, then discard the marker column
    df = df[df['__keep__']].drop(columns='__keep__').reset_index(drop=True)

    print(f"[{datetime.now().strftime('%x %X')}] ✂️ After trimming: shape={df.shape}")
    print(f"[{datetime.now().strftime('%x %X')}] ✂️ Label values: {dict(sorted(df['Label'].value_counts().items()))}")


    ## Save the trimmed data ##
    df.to_parquet(trimed_file, index=False)
    print(f"[{datetime.now().strftime('%x %X')}] 💾 Saved trimmed data: {trimed_file}")
else:
    print(f"[{datetime.now().strftime('%x %X')}] ⏭️ {trimed_file} already exists, loading it directly")
    df = pd.read_parquet(trimed_file)
    print(f"[{datetime.now().strftime('%x %X')}] shape={df.shape}")
    print(f"[{datetime.now().strftime('%x %X')}] Label values: {dict(sorted(df['Label'].value_counts().items()))}")

[06/04/25 17:39:32] 🚀 Loading cleaned data from /content/drive/MyDrive/NYIT/880/data/preprocessed/cleaned.parquet
[06/04/25 17:39:44] shape=(9320035, 79)
[06/04/25 17:39:44] Label values: {0: 6573101, 1: 286191, 2: 611, 3: 230, 4: 686012, 5: 1730, 6: 576191, 7: 41508, 8: 461912, 9: 139890, 10: 10990, 11: 193354, 12: 160639, 13: 87, 14: 187589}
[06/04/25 17:39:47] 🧹 Found 8 all-zero columns at positions: [32, 34, 56, 57, 58, 59, 60, 61]
[06/04/25 17:39:50] 🧹 After dropping: shape=(9320035, 71)
[06/04/25 17:39:53] ⚠️ Handling values below -1: found 15 rows containing a value below -1
[06/04/25 17:39:56] ⚠️ After handling: shape=(9320020, 71)
[06/04/25 17:39:56] ⚠️ Label values: {0: 6573086, 1: 286191, 2: 611, 3: 230, 4: 686012, 5: 1730, 6: 576191, 7: 41508, 8: 461912, 9: 139890, 10: 10990, 11: 193354, 12: 160639, 13: 87, 14: 187589}
[06/04/25 17:40:08] ✂️ After trimming: shape=(4746934, 71)
[06/04/25 17:40:08] ✂️ Label values: {0: 2000000, 1: 286191, 2: 611, 3: 230, 4: 686012, 5: 1730, 6: 576191, 7: 41508, 8: 461912, 9: 139890, 10: 10990, 11: 193354, 12: 160639, 13: 87, 14: 187589}
[06/04/25

## Data Summary
- shape
- number of samples per label
- value distribution of the feature columns
- negative-value counts

In [None]:
report = f"[{datetime.now().strftime('%x %X')}] === Data Summary ==="

report_file = preprocessed_folder / 'trimed_report.txt'

pd.set_option('display.max_rows', None)     # show all rows (one per statistic in df.describe())
pd.set_option('display.max_columns', None)  # show all columns (one per feature in df.describe())

report += "\n" + f"[{datetime.now().strftime('%x %X')}] Dataset shape: {df.shape}"
report += "\n" + f"[{datetime.now().strftime('%x %X')}] Samples per label:\n{dict(sorted(df['Label'].value_counts().items()))}"
report += "\n" + f"[{datetime.now().strftime('%x %X')}] Distribution statistics: \n{df.describe()}"

negative_counts = (df < 0).sum()
negative_counts_filtered = {
    k: int(v) for k, v in negative_counts.items() if v > 0
}
report += "\n"*2 + f"[{datetime.now().strftime('%x %X')}] Negative values: [{len(negative_counts_filtered)}] columns contain negative values:"
report += "\n" + f"{negative_counts_filtered}"

print(f"[{datetime.now().strftime('%x %X')}] Writing summary to file: {report_file}")
with open(report_file, 'w', encoding='utf-8') as f:
    f.write(report)

[06/04/25 17:40:48] Writing summary to file: /content/drive/MyDrive/NYIT/880/data/preprocessed/trimed_report.txt


## Dataset split
- train set: 80%
- valid set: 10%
- test set: 10%
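An 80/10/10 split can be obtained from two successive stratified splits: first hold out 20%, then cut the held-out part in half. The sketch below shows this on toy data. The array sizes, the three-class label distribution, and the name `X_hold` are assumptions made for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data: 1000 samples, 4 features, 3 imbalanced classes.
rng = np.random.default_rng(0)
X = rng.random((1000, 4), dtype=np.float32)
y = np.repeat([0, 1, 2], [800, 150, 50]).astype(np.int32)

# Step 1: hold out 20% of the data; stratify keeps the class proportions.
X_train, X_hold, y_train, y_hold = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)

# Step 2: split the held-out 20% in half -> 10% validation, 10% test.
X_valid, X_test, y_valid, y_test = train_test_split(
    X_hold, y_hold, test_size=0.5, random_state=42, stratify=y_hold)

print(len(y_train), len(y_valid), len(y_test))  # 800 100 100
```

Stratifying both splits matters for rare classes: without it, a label with only a few dozen samples could end up missing from the validation or test set entirely.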

In [None]:
from sklearn.model_selection import train_test_split

if splits_folder.exists():
    print(f"[{datetime.now().strftime('%x %X')}] ⚠️ {splits_folder} already exists; its contents will be overwritten")
else:
    print(f"[{datetime.now().strftime('%x %X')}] 📁 Creating directory: {splits_folder}")
    splits_folder.mkdir(parents=True, exist_ok=True)

# Separate the feature matrix from the labels
features = df.columns.difference(['Label'])
X = df[features].to_numpy(dtype=np.float32)
y = df['Label'].to_numpy(dtype=np.int32)

print(f'Original features: {X.shape}')
print(f'Original labels: {get_label_counts(y)}\n')

# Split whole => 0.8 : 0.2
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
del X, y

# Split 0.2 => 0.1 : 0.1
X_valid, X_test, y_valid, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=42, stratify=y_test)

print(f'X_train shape: {X_train.shape}')
print(f'y_train labels: {get_label_counts(y_train)}\n')
np.save(splits_folder / 'train_X.npy', X_train)
np.save(splits_folder / 'train_y.npy', y_train)

print(f'X_valid shape: {X_valid.shape}')
print(f'y_valid labels: {get_label_counts(y_valid)}\n')
np.save(splits_folder / 'valid_X.npy', X_valid)
np.save(splits_folder / 'valid_y.npy', y_valid)

print(f'X_test shape: {X_test.shape}')
print(f'y_test labels: {get_label_counts(y_test)}\n')
np.save(splits_folder / 'test_X.npy', X_test)
np.save(splits_folder / 'test_y.npy', y_test)

print(f'[{datetime.now().strftime("%x %X")}] ✅ Saved split datasets to {splits_folder}/ (train, valid, test)')

[06/04/25 18:00:36] ⚠️ /content/drive/MyDrive/NYIT/880/data/preprocessed/splits already exists; its contents will be overwritten
Original features: (4746934, 70)
Original labels: {0: 2000000, 1: 286191, 2: 611, 3: 230, 4: 686012, 5: 1730, 6: 576191, 7: 41508, 8: 461912, 9: 139890, 10: 10990, 11: 193354, 12: 160639, 13: 87, 14: 187589}

X_train shape: (3797547, 70)
y_train labels: {0: 1600000, 1: 228953, 2: 489, 3: 184, 4: 548809, 5: 1384, 6: 460953, 7: 33206, 8: 369530, 9: 111912, 10: 8792, 11: 154683, 12: 128511, 13: 70, 14: 150071}

X_valid shape: (474693, 70)
y_valid labels: {0: 200000, 1: 28619, 2: 61, 3: 23, 4: 68601, 5: 173, 6: 57619, 7: 4151, 8: 46191, 9: 13989, 10: 1099, 11: 19335, 12: 16064, 13: 9, 14: 18759}

X_test shape: (474694, 70)
y_test labels: {0: 200000, 1: 28619, 2: 61, 3: 23, 4: 68602, 5: 173, 6: 57619, 7: 4151, 8: 46191, 9: 13989, 10: 1099, 11: 19336, 12: 16064, 13: 8, 14: 18759}

[06/04/25 18:01:29] ✅ Saved split datasets to /content/drive/MyDrive/NYIT/880/data/preprocessed/splits/ (train, vali