用户可定义IOH模块

In [5]:
# from numba import njit
import numpy as np
# @njit
def Check_If_IOH(time_series, srate, IOH_value, duration):
    """
    Check if there is a period of intraoperative hypotension (IOH) in the time series.

    Parameters:
    - time_series (1D array-like): The blood pressure time series.
    - srate: Sampling rate of the time series (samples per second).
    - IOH_value: Threshold value for IOH (blood pressure below this is considered hypotensive).
    - duration: duration in seconds that defines IOH (must stay below IOH_value for this period).

    Returns:
    - bool: True if IOH is detected, otherwise False.
    """
    # 确保 time_series 是 numpy array
    if isinstance(time_series, list):
        time_series = np.array(time_series)

    # 将Duration转换为采样点数
    duration_samples = int(duration * srate)
    
    # 如果时间序列长度小于duration_samples，不可能满足IOH条件，直接返回False
    if len(time_series) < duration_samples:
        return False
    
    # 创建一个布尔掩码数组，标记低于IOH阈值的点
    below_threshold = time_series < IOH_value
    
    # 使用滑动窗口检查是否存在连续的duration_samples个值都低于IOH_value
    for i in range(len(below_threshold) - duration_samples + 1):
        # 检查当前滑动窗口内的所有值是否都为True（即都低于IOH_value）
        if np.all(below_threshold[i:i + duration_samples]):
            return True
    
    return False

def sliding_window_average(time_series, slide_len):
    if slide_len <= 0:
        raise ValueError("slide_len must be greater than 0")
    
    # 存储滑动窗口的平均值
    window_averages = []
    
    # 遍历序列，按滑动窗口大小取值
    for i in range(0, len(time_series), slide_len):
        # 获取当前窗口的值
        window = time_series[i:i + slide_len]
        # 计算窗口的平均值并存储
        window_avg = round(np.nanmean(window), 2)
        window_averages.append(window_avg)
    
    return window_averages


# 只适用s为单位的抽样,且 slide_len必须能被stime整除，以及Duration能被slide_len整除，不然求出来的序列有误差。
def Check_If_IOH_Combined_S(time_series, stime, IOH_value, Duration, slide_len):
    # Duration 和 滑动窗口长度转为采样点
    duration_samples = int(Duration / slide_len)
    slide_samples = slide_len
    
    # 计算滑动窗口的平均值
    if slide_samples == 1:
        smoothed_series = time_series
    else:
        # smoothed_series = Count_Windows_MovingAvg(time_series, slide_samples)
        smoothed_series = sliding_window_average(time_series, slide_samples)

    # Step 2: 对平滑后的序列进行低血压检测
    if duration_samples == 1:
        evt = np.nanmax(smoothed_series) < IOH_value
    else:
        # Step 3: 逐点判断
        evt = Check_If_IOH(smoothed_series, 1, IOH_value, duration_samples)
        # evt = np.nanmax(sliding_window_value) < IOH_value

    # print("evt:", evt, "max:", np.nanmax(sliding_window_count) )
    return evt

def round_data_to_two_decimals(data):
    return np.where(np.isnan(data), data, np.round(data, 2))

def user_definable_IOH(time_series):
    predcition_lables = []
    for i in time_series:
        predcition_lables.append(Check_If_IOH_Combined_S(i, 1, 65, 30, 1))
    return predcition_lables
    

## VitalDB低血压数据数据切片和重组：

In [6]:
import pandas as pd
import numpy as np
from collections import defaultdict
def get_batched_data_fn(
    batch_size: int = 128, 
    # context_len: int = 120, 
    # horizon_len: int = 24,
):
    # 读取CSV文件
    csv_file = '/home/likx/time_series_forecasting/IOH_Datasets_Preprocess/vitaldb/vitaldb_test_data.csv'
    data = pd.read_csv(csv_file)

    # # 使用前向填充和后向填充处理NaN值
    # data = data.ffill().bfill()

    # 数据预处理前总数据
    print("源数据长度：", len(data))
    # 统计处理后 examples 中 label 列的分布
    label_counts = data['label'].value_counts(normalize=True) * 100
    print("处理前的Label分布 (%):")
    print(label_counts)

    # 移除 prediction_mbp 为 '[]' 的行
    # data = data.loc[data['prediction_mbp'] != '[]']

    # 定义处理序列数据的函数，直接通过空格拆分并转换为浮点数列表，且完成重采样
    def parse_sequence(sequence_str, skip_rate=0, sample_type='avg_sample'):
        try:
            sequence_list = sequence_str.split()
            sequence_array = np.array([np.nan if x == 'nan' else float(x) for x in sequence_list])
            mean_value = round(np.nanmean(sequence_array), 2)
            sequence_array_filled = np.where(np.isnan(sequence_array), mean_value, sequence_array)
            if np.any(np.isnan(sequence_array_filled)):
                return [] 
                    
            if skip_rate > 0: # 如果需要重采样
                if sample_type == 'skip_sample':
                    sequence_array_filled = sequence_array_filled[::skip_rate]
                elif sample_type == 'avg_sample': #默认按平均值进行采样
                    sequence_array_filled = sliding_window_average(sequence_array_filled, skip_rate)

            return sequence_array_filled
        except ValueError:
            return [] 
    # 初始化 defaultdict
    examples = defaultdict(list)

    for index, row in data.iterrows():
        bts = parse_sequence(row['bts'][1:-1], skip_rate=0, sample_type='skip_sample') #采样周期是：2*skip_rate
        hrs = parse_sequence(row['hrs'][1:-1], skip_rate=0, sample_type='skip_sample')
        dbp = parse_sequence(row['dbp'][1:-1], skip_rate=0, sample_type='skip_sample')
        mbp = parse_sequence(row['mbp'][1:-1], skip_rate=0, sample_type='skip_sample')
        prediction_mbp = parse_sequence(row['prediction_mbp'][1:-1], skip_rate=0, sample_type='skip_sample')
        # print(len(bts), len(hrs), len(dbp), len(mbp), len(prediction_mbp))
        if len(bts) != 450 or len(hrs) != 450 or len(dbp) != 450 or\
            len(mbp) != 450 or len(prediction_mbp) != 150:
            continue
        
        examples['caseid'].append(row['caseid'])
        examples['stime'].append(row['stime'])
        examples['ioh_stime'].append(row['ioh_stime'])
        examples['ioh_dtime'].append(row['ioh_dtime'])
        examples['age'].append(row['age']) # np.full(len(bts), row['age'])
        examples['sex'].append(row['sex'])
        examples['bmi'].append(row['bmi'])
        examples['label'].append(Check_If_IOH_Combined_S(prediction_mbp, 1, 65, 30, 1))
        examples['bts'].append(bts)
        examples['hrs'].append(hrs)
        examples['dbp'].append(dbp)
        examples['inputs'].append(mbp)
        examples['outputs'].append(prediction_mbp)

    # 修正统计处理后的样本数量
    print("处理后的测试样本数量:", len(examples['caseid']))

    # 统计处理后 examples 中 label 列的分布
    label_counts = pd.Series(examples['label']).value_counts(normalize=True) * 100
    print("处理后的Label分布 (%):")
    print(label_counts)
# examples

    '''
        (num_examples - 1) // batch_size：这是整除运算，用来计算可以完全容纳 batch_size 个样本的完整批次数量。
        1 +：确保即使在最后的批次不足 batch_size 的时候，仍然会生成这个未满的批次。
    '''
    def data_fn(): # 批次生成器函数
        for i in range(1 + (len(data) - 1) // batch_size):
            yield {k: v[(i * batch_size) : ((i + 1) * batch_size)] for k, v in examples.items()}
    
    return data_fn

In [7]:
from sklearn.metrics import auc, classification_report, confusion_matrix, accuracy_score, roc_curve, roc_auc_score, f1_score, precision_recall_curve

# Define metrics
def mse(y_pred, y_true):
  y_pred = np.array(y_pred)
  y_true = np.array(y_true)
  return np.mean(np.square(y_pred - y_true), axis=1, keepdims=True)

def mae(y_pred, y_true):
  y_pred = np.array(y_pred)
  y_true = np.array(y_true)
  return np.mean(np.abs(y_pred - y_true), axis=1, keepdims=True)

## TimesFM模型加载与测试

In [8]:

import timesfm
import pandas as pd
import numpy as np
from collections import defaultdict
import time

context_len = 450
horizon_len = 150

# 数据集加载与批量化
bs = 128
input_data = get_batched_data_fn(batch_size=bs)

# Loading TimesFM in pytorch version
tfm = timesfm.TimesFm(
      hparams=timesfm.TimesFmHparams(
          backend="gpu",
          per_core_batch_size=32,      
          horizon_len=horizon_len,
      ),
      checkpoint=timesfm.TimesFmCheckpoint(
         version="torch",
          # huggingface_repo_id="google/timesfm-1.0-200m-pytorch"
          path="/home/data/times-forecasting/checkpoints/timesfm-1.0-200m-pytorch/torch_model.ckpt"),
  )
print("Loading Model Finish.")

# Loading TimesFM in JAX/PAX version
# tfm = timesfm.TimesFm(
#       hparams=timesfm.TimesFmHparams(
#           backend="gpu",
#           per_core_batch_size=32,      
#           horizon_len=horizon_len,
#       ),
#       checkpoint=timesfm.TimesFmCheckpoint(
#          version="jax",
#          step=1100000,
#          path="/home/likx/time_series_forecasting/datasets_and_checkpoints/timesfm-1.0-200m/checkpoints/"),
#   )
# print("Loading Model Finish.")


# Benchmark
metrics = defaultdict(list)
ground_true_labels = []
for i, example in enumerate(input_data()):
    if np.array(example["inputs"]).shape != (bs, context_len):
        continue

    raw_forecast, _ = tfm.forecast(
        inputs=example["inputs"], freq=[0] * len(example["inputs"])
    )
    start_time = time.time()
    # Forecast with covariates
    # Output: new forecast, forecast by the xreg
    cov_forecast, ols_forecast = tfm.forecast_with_covariates(  
        inputs=example["inputs"],
        dynamic_numerical_covariates={
            #   "body_temperature": example["bts"],
            #   "heart_rate": example["hrs"],
            #   "diastolic_blood_pressure": example["dbp"],
        },
        dynamic_categorical_covariates={},
        static_numerical_covariates={
            "age": example["age"],
            "body_mass_index": example["bmi"],
        },
        static_categorical_covariates={
            "gender": example["sex"],
        },
        freq=[0] * len(example["inputs"]),
        xreg_mode="xreg + timesfm",              # default
        ridge=0.0,
        force_on_cpu=False,
        normalize_xreg_target_per_input=True,    # default
    )

    print(
        f"\rFinished batch {i} linear in {time.time() - start_time} seconds",
        end="",
    )
    # print()
    raw_forecast = raw_forecast[:, :horizon_len]
    true_series = np.array(example["outputs"])[:, :horizon_len]
    ground_true_labels.extend(example["label"])

    metrics["eval_mae_timesfm"].extend(mae(raw_forecast, true_series))
    metrics["eval_mae_xreg_timesfm"].extend(mae(cov_forecast, true_series))
    metrics["eval_mae_xreg"].extend(mae(ols_forecast, true_series))
    metrics["eval_mse_timesfm"].extend(mse(raw_forecast[:, :horizon_len], true_series))
    metrics["eval_mse_xreg_timesfm"].extend(mse(cov_forecast, true_series))
    metrics["eval_mse_xreg"].extend(mse(ols_forecast, true_series))
    metrics["eval_pred_lable_timesfm"].extend(user_definable_IOH(raw_forecast))
    metrics["eval_pred_lable_xreg_timesfm"].extend(user_definable_IOH(cov_forecast))
    metrics["eval_pred_lable_xreg"].extend(user_definable_IOH(ols_forecast))

print()

for k, v in metrics.items():
  if k in ["eval_pred_lable_timesfm", "eval_pred_lable_xreg_timesfm", "eval_pred_lable_xreg"]:
    print(k, "--Prediction Results:")
    precision, recall, thmbps = precision_recall_curve(ground_true_labels, v)
    auprc = auc(recall, precision)

    fpr, tpr, thmbps = roc_curve(ground_true_labels, v)
    auroc = auc(fpr, tpr)
    f1 = f1_score(ground_true_labels, v)
    acc = accuracy_score(ground_true_labels, v)
    tn, fp, fn, tp = confusion_matrix(ground_true_labels, v).ravel()

    testres = 'auroc={:.3f}, auprc={:.3f} acc={:.3f}, F1={:.3f}, PPV={:.1f}, NPV={:.1f}, TN={}, fp={}, fn={}, TP={}'.format(auroc, auprc, acc, f1, tp/(tp+fp)*100, tn/(tn+fn)*100, tn, fp, fn, tp)
    print(testres)
  else:   
    print(f"{k}: {np.mean(v)}")

源数据长度： 14513
处理前的Label分布 (%):
label
False    87.900503
True     12.099497
Name: proportion, dtype: float64


  mean_value = round(np.nanmean(sequence_array), 2)


处理后的测试样本数量: 13073
处理后的Label分布 (%):
False    92.342997
True      7.657003
Name: proportion, dtype: float64
Loading Model Finish.
Finished batch 101 linear in 0.34670519828796387 seconds
eval_mae_timesfm: 5.582376960559375
eval_mae_xreg_timesfm: 10.527234063120654
eval_mae_xreg: 83.0502415494033
eval_mse_timesfm: 86.06718922410472
eval_mse_xreg_timesfm: 216.33579340607128
eval_mse_xreg: 7089.462976952161
eval_pred_lable_timesfm --Prediction Results:
auroc=0.793, auprc=0.631 acc=0.941, F1=0.616, PPV=61.3, NPV=96.8, TN=11663, fp=392, fn=381, TP=620
eval_pred_lable_xreg_timesfm --Prediction Results:
auroc=0.615, auprc=0.305 acc=0.877, F1=0.276, PPV=25.2, NPV=94.1, TN=11146, fp=909, fn=695, TP=306
eval_pred_lable_xreg --Prediction Results:
auroc=0.500, auprc=0.538 acc=0.077, F1=0.142, PPV=7.7, NPV=nan, TN=0, fp=12055, fn=0, TP=1001


  testres = 'auroc={:.3f}, auprc={:.3f} acc={:.3f}, F1={:.3f}, PPV={:.1f}, NPV={:.1f}, TN={}, fp={}, fn={}, TP={}'.format(auroc, auprc, acc, f1, tp/(tp+fp)*100, tn/(tn+fn)*100, tn, fp, fn, tp)
