In [1]:
from esfMRI import sliceWindows, joint_cluster_save_states, plot_sates, align, clustering_evaluate, step_evaluate
from sklearn import cluster, metrics
from nilearn import connectome
import numpy as np
import pickle
import math
import os

### 估计聚类簇数（k）

1. 肘点法（Elbow method）：绘制 inertia（各样本到其所属簇中心距离的平方和）随 k 值变化的曲线，取曲线弯折最明显处（“肘点”）对应的 k 作为簇数。

### 评估聚类质量

由于没有已知的真实类别标签，只能采用内部评价指标。  
基础参数有
1. 紧密度（Compactness）
2. 分割度（Separation）
3. 误差平方和（SSE: Sum of squares of errors）

评价指标
1. 轮廓系数（Silhouette Coefficient） —— 越大越好
2. Calinski-Harabasz Index（CH） —— 越大越好
3. Davies-Bouldin Index（DB） —— 越小越好

### 模型评估方法

1. 贝叶斯信息准则（BIC） —— 越小越好
2. 赤池信息量准则（AIC） —— 越小越好

In [2]:
# --- Tunable parameters ---
# Sliding-window lengths to try, in seconds (converted to frames by dividing by TR).
window_length_Second = [
    60,
    120,
    150,
    180,
]
# Offset between consecutive windows, in seconds (converted to frames by dividing by TR).
sliding_step = 10
# Candidate numbers of states, i.e. the k passed to the clustering routines.
target_states = [5, 6, 8]

### 基于状态（state-based）方法的动态指标

**dFC强度（dFC strength）**：FC在给定状态中的强度。  
**停留时间（Dwell time）**：受试者每次进入某一状态后连续停留的平均时间。  
**占用率（Occupancy rate）**：扫描期间每个状态发生的时间百分比。  
**转换概率矩阵（Transition matrix）**：从一种状态转换到另一种状态的概率。  
**平均可变性指数（Average variability index）**：它表示功能源的整体动态水平。可变性指数定义为二项分布的标准差，用于衡量某一区域与给定源之间关联的可变程度。  
**功能（域间）状态连接（Functional (inter-domain) state connectivity）**：当一种技术（例如空间动态层次）分别估计每个源的动态状态时，它可以捕捉不同源（例如功能域）状态之间的并发性。  

In [3]:
# Load the per-subject time series. Later cells access it as
# data[subid]["ses-preop" | "ses-postop"][run]["TR" | "time_series"].
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open("time_series2.pkl", mode="rb") as ts_file:
    data = pickle.load(ts_file)

In [5]:
# Effect of the sliding step on joint (pre-op + post-op) clustering: for every
# subject, window length and k, step_evaluate writes its output under
# cluster_evaluate/joint/<subid>/<k>_states.
candidate_steps = range(1, 10)  # passed to esfMRI.step_evaluate; units defined there — TODO confirm
for subid, sessions in data.items():
    for time in window_length_Second:
        for k in target_states:
            save_dir = f"cluster_evaluate/joint/{subid}/{k}_states"
            os.makedirs(save_dir, exist_ok=True)
            step_evaluate(
                time,
                candidate_steps,
                k,
                save_dir,
                sessions["ses-preop"],
                sessions["ses-postop"],
            )

In [4]:
# Clustering-quality evaluation on each subject's concatenated windows
# (pre-op only, post-op only, and both), for a single window length.
EVAL_WINDOW_SECONDS = 150


def _collect_session_windows(session, window_seconds, step_seconds):
    """Slice every run of one session into sliding windows and concatenate them.

    Window length and step are given in seconds and converted to frames with
    each run's own TR, rounded up.
    """
    windows = []
    for items in session.values():
        window_frames = math.ceil(window_seconds / items["TR"])
        step_frames = math.ceil(step_seconds / items["TR"])
        windows += sliceWindows(items["time_series"], window_frames, step_frames)
    return windows


for subid in data:
    for time in window_length_Second:
        if time != EVAL_WINDOW_SECONDS:
            continue
        windows_preop = _collect_session_windows(data[subid]["ses-preop"], time, sliding_step)
        windows_postop = _collect_session_windows(data[subid]["ses-postop"], time, sliding_step)
        save_dir = f"cluster_evaluate/joint/{subid}"
        os.makedirs(save_dir, exist_ok=True)
        # range(2, 15): candidate numbers of clusters scored by clustering_evaluate.
        clustering_evaluate(windows_preop, range(2, 15), f"{save_dir}/preop_{time}.png")
        clustering_evaluate(windows_postop, range(2, 15), f"{save_dir}/postop_{time}.png")
        clustering_evaluate(windows_preop + windows_postop, range(2, 15), f"{save_dir}/total_{time}.png")

In [None]:
# Joint clustering: for each subject, window length and k, cluster the pre-op
# and post-op windows together, pickle the state sequences, and plot them for
# the whole session and for each run.
#
# Each k keeps its own result dict, so {k}states.pkl contains only labels from
# that k. A dict shared across k would leave earlier window lengths holding the
# labels of whichever k ran last.


def _windows_per_run(session, window_seconds):
    """Slice each run of a session into windows (1-frame step) and concatenate them.

    Returns (windows, run_lengths), where run_lengths[i] is the number of
    windows contributed by the i-th run in iteration order. Each run's own TR
    is used, so sessions without a "run-01" key, or with differing TRs, work.
    """
    windows, run_lengths = [], []
    for items in session.values():
        window_frames = math.ceil(window_seconds / items["TR"])
        run_windows = sliceWindows(items["time_series"], window_frames, 1)
        run_lengths.append(len(run_windows))
        windows += run_windows
    return windows, run_lengths


def _plot_runs(labels, run_lengths, prefix, save_dir):
    """Plot the part of a concatenated state sequence that belongs to each run."""
    start = 0
    for run_idx, length in enumerate(run_lengths):
        plot_sates(labels[start:start + length], f"{save_dir}/{prefix}_run{run_idx + 1:0>2d}.png")
        start += length


for subid in data:
    states_by_k = {k: {} for k in target_states}
    pkl_dir = f"states_pkl/joint/{subid}"
    os.makedirs(pkl_dir, exist_ok=True)
    for time in window_length_Second:
        windows_preop, window_length_preop = _windows_per_run(data[subid]["ses-preop"], time)
        windows_postop, window_length_postop = _windows_per_run(data[subid]["ses-postop"], time)
        for k in target_states:
            states_k = joint_cluster_save_states(windows_preop, windows_postop, k)
            states = states_by_k[k]
            states[time] = {
                "length": [window_length_preop, window_length_postop],
                "preop": states_k[0],
                "postop": states_k[1],
            }
            # Rewritten after every window length so partial results survive an interruption.
            with open(f"{pkl_dir}/{k}states.pkl", "wb") as f:
                pickle.dump(states, f)

            save_dir = f"states/joint/{k}states/{time}/{subid}"
            os.makedirs(save_dir, exist_ok=True)
            plot_sates(states[time]["preop"], f"{save_dir}/preop.png")
            plot_sates(states[time]["postop"], f"{save_dir}/postop.png")
            _plot_runs(states[time]["preop"], window_length_preop, "preop", save_dir)
            _plot_runs(states[time]["postop"], window_length_postop, "postop", save_dir)

In [None]:
# Average-then-cluster: average each session's runs (after aligning them to a
# fixed number of frames), then jointly cluster the windows of the two
# averaged series. Each k keeps its own result dict, so {k}states.pkl contains
# only labels from that k.
align_length_preop = 130
align_length_postop = 200


def _average_runs(session, align_length):
    """Align every run with at least `align_length` frames and average them.

    Returns (mean_series, TR), with TR taken from the first run used (not
    blindly from "run-01", which may be skipped or absent), or (None, None)
    when no run is long enough.
    """
    aligned, tr = [], None
    for items in session.values():
        if items["time_series"].shape[0] < align_length:
            continue  # too short to align to the common length
        aligned.append(align(items["time_series"], align_length))
        if tr is None:
            tr = items["TR"]
    if not aligned:
        return None, None
    # sum() accumulates in run order, like an element-wise running total.
    return sum(aligned) / len(aligned), tr


for subid in data:
    time_series_preop, tr_preop = _average_runs(data[subid]["ses-preop"], align_length_preop)
    if time_series_preop is None:
        print(subid, "pre")
        continue
    time_series_postop, tr_postop = _average_runs(data[subid]["ses-postop"], align_length_postop)
    if time_series_postop is None:
        print(subid, "post")
        continue

    states_by_k = {k: {} for k in target_states}
    pkl_dir = f"states_pkl/average/{subid}"
    for time in window_length_Second:
        windows_preop = sliceWindows(time_series_preop, math.ceil(time / tr_preop), 1)
        windows_postop = sliceWindows(time_series_postop, math.ceil(time / tr_postop), 1)
        for k in target_states:
            # Skip only when neither averaged series yields more than k windows.
            if len(windows_preop) <= k and len(windows_postop) <= k:
                continue
            states_k = joint_cluster_save_states(windows_preop, windows_postop, k)
            states = states_by_k[k]
            states[time] = {"preop": states_k[0], "postop": states_k[1]}
            os.makedirs(pkl_dir, exist_ok=True)
            with open(f"{pkl_dir}/{k}states.pkl", "wb") as f:
                pickle.dump(states, f)

            save_dir = f"states/average/{k}states/{time}/{subid}"
            os.makedirs(save_dir, exist_ok=True)
            plot_sates(states[time]["preop"], f"{save_dir}/preop.png")
            plot_sates(states[time]["postop"], f"{save_dir}/postop.png")

In [3]:
# Pre-computed sliding windows per subject; the next cell reads
# slidingWindows[subid]["ses-preop" | "ses-postop"]["total"].
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open("slidingWindows.pkl", mode="rb") as windows_file:
    slidingWindows = pickle.load(windows_file)

In [4]:
# Clustering-quality evaluation on the windows of all subjects pooled together
# (pre-op only, post-op only, and both).
save_path = "cluster_evaluate/total"
os.makedirs(save_path, exist_ok=True)
time = 180  # only used to label the output file names
windows_preop = [
    window
    for sessions in slidingWindows.values()
    for window in sessions["ses-preop"]["total"]
]
windows_postop = [
    window
    for sessions in slidingWindows.values()
    for window in sessions["ses-postop"]["total"]
]
candidate_ks = range(2, 15)
clustering_evaluate(windows_preop, candidate_ks, f"{save_path}/{time}_preop.png")
clustering_evaluate(windows_postop, candidate_ks, f"{save_path}/{time}_postop.png")
clustering_evaluate(windows_preop + windows_postop, candidate_ks, f"{save_path}/{time}_total.png")

In [3]:
# Dynamic FC matrices per subject; the next cell reads
# dFCs[subid]["ses-preop" | "ses-postop"][run or "total"] as arrays of windows.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open("dFCs.pkl", mode="rb") as dfc_file:
    dFCs = pickle.load(dfc_file)

In [9]:
# Pool the dFC windows of all subjects (pre-op rows first, then post-op rows),
# fit one k-means model per k, then predict and plot the state sequence of
# every individual run.
save_path = "states/total"
os.makedirs(save_path, exist_ok=True)


def _stack_session_dfcs(session):
    """Flatten one session's dFC windows into 2-D row blocks.

    Uses the session's "total" entry when present, otherwise every run in
    iteration order. Each entry is reshaped to (n_windows, -1), so any
    per-window matrix size works; the original hard-coded 13456 (= 116 * 116).
    Returns a list of 2-D arrays, which may be empty.
    """
    keys = ["total"] if "total" in session else list(session)
    return [session[key].reshape((session[key].shape[0], -1)) for key in keys]


# The pooled data does not depend on the window length or k, so build it once
# with a single vstack instead of growing an array inside the loops.
dfc_blocks = []
for subid in dFCs:
    dfc_blocks += _stack_session_dfcs(dFCs[subid]["ses-preop"])
for subid in dFCs:
    dfc_blocks += _stack_session_dfcs(dFCs[subid]["ses-postop"])
dfcs_all = np.vstack(dfc_blocks)

for time in window_length_Second:
    if time != 180:
        continue  # NOTE(review): presumably dFCs.pkl was computed with 180 s windows — confirm
    for k in target_states:
        # Fixed seed so the fitted states are reproducible between runs.
        km = cluster.KMeans(n_clusters=k, random_state=0)
        km.fit(dfcs_all)
        model_dir = f"{save_path}/cluster"
        os.makedirs(model_dir, exist_ok=True)
        with open(f"{model_dir}/km_{time}s_{k}states.pkl", "wb") as f:
            pickle.dump(km, f)

        # states[subid]["ses-preop" | "ses-postop"][run] -> predicted labels.
        # The results are keyed by session because run names can repeat across
        # sessions. A flat states[subid][run] let post-op results overwrite
        # pre-op ones.
        states = {}
        for subid in dFCs:
            save_dir = f"{save_path}/{k}states/{time}/{subid}"
            os.makedirs(save_dir, exist_ok=True)
            states[subid] = {}
            for session, prefix in (("ses-preop", "preop"), ("ses-postop", "postop")):
                session_states = {}
                for run, items in dFCs[subid][session].items():
                    if run == "total":
                        continue  # only individual runs are predicted and plotted
                    session_states[run] = km.predict(items.reshape((items.shape[0], -1)))
                    plot_sates(session_states[run], f"{save_dir}/{prefix}_{run}.png")
                states[subid][session] = session_states

        pkl_dir = f"{save_path}/pkl"
        os.makedirs(pkl_dir, exist_ok=True)
        with open(f"{pkl_dir}/{time}_{k}.pkl", "wb") as f:
            pickle.dump(states, f)