In [1]:
import os
import sys
# del os.environ['MKL_NUM_THREADS'] # error corrected by MH 10/12/2022 (add these three lines)
import torch
from torch.autograd import Variable
import torch.utils.data as utils
import numpy as np
import gc
import sys
import scipy.io as spio
import h5py
import RESNET152_ATT_naive

In [2]:
def loadmat(filename):
    '''
    读取 MATLAB v7.3 `.mat` 文件（Whole_tracks 作为 tracks）
    '''
    output = dict()
    
    # 打开 HDF5 MAT 文件
    with h5py.File(filename, 'r') as data:
        # 读取 Whole_tracks 变量
        if 'Whole_tracks' not in data:
            raise KeyError("❌ 错误: 'Whole_tracks' 变量不存在！")

        whole_tracks = data['Whole_tracks']  # 结构体 Whole_tracks

        # 确保它有 `count` 和 `data`
        if 'count' not in whole_tracks or 'data' not in whole_tracks:
            raise KeyError(f"❌ 错误: 'Whole_tracks' 结构不完整！包含: {list(whole_tracks.keys())}")

        # 读取 count（可能是字符编码格式，需要解析）
        count = whole_tracks['count'][()]  
        print("🔍 Whole_tracks['count'] 数据:", count)
        print("🔍 数据类型:", type(count))

        # 直接转换成整数
        total_count = int(count.item())
        print(f'total_count: {total_count}')
        # 读取 Whole_tracks['data']
        track = []
        for i in range(total_count):
            data_ref = whole_tracks['data'][i].item()
            track.append(np.transpose(data[data_ref][:]).astype(np.float32))

        # 组织输出
        output['tracks'] = {
            'count': total_count,
            'data': track
        }
    
    return output

def _check_keys(dict):
    '''
    checks if entries in dictionary are mat-objects. If yes
    todict is called to change them to nested dictionaries
    '''
    for key in dict:
        if isinstance(dict[key], spio.matlab.mio5_params.mat_struct):
            dict[key] = _todict(dict[key])
    return dict        

def _todict(matobj):
    '''
    A recursive function which constructs from matobjects nested dictionaries
    '''
    dict = {}
    for strg in matobj._fieldnames:
        elem = matobj.__dict__[strg]
        if isinstance(elem, spio.matlab.mio5_params.mat_struct):
            dict[strg] = _todict(elem)
        else:
            dict[strg] = elem
    return dict

#%%
def mySoftmax(z):
    assert len(z.shape) == 2
    s = np.max(z, axis=1)
    s = s[:, np.newaxis] # necessary step to do broadcasting
    e_x = np.exp(z - s)
    div = np.sum(e_x, axis=1)
    div = div[:, np.newaxis] # dito
    return e_x / div
"""normalize"""#110
def rescale(X_list,count):
    output=list()
    if count==1:
        output.append(X_list/110)
        return output
    for i in range(len(X_list)):
        output.append(X_list[i]/110)
    return output

def udflip(X_nparray, y_nparray, shuffle=True):

    if X_nparray.shape[2] == 4:
        if np.std(X_nparray[:, 0, :]) > np.std(X_nparray[:, -1, :]):
            print("Detected special info in first column, swapping...")
            X_nparray = np.concatenate((X_nparray[:, 1:, :], X_nparray[:, 0:1, :]), axis=1)
    
    X_flipped = np.flip(X_nparray, axis=2)  

    X_aug = np.vstack((X_nparray, X_flipped))
    y_aug = np.hstack((y_nparray, y_nparray))  

    if shuffle:
        shuffle_idx = np.random.permutation(X_aug.shape[0])
        return X_aug[shuffle_idx], y_aug[shuffle_idx]
    else:
        return X_aug, y_aug
def datato3d(arrays):#list of np arrays, NULL*3*100
    output=list()
    for i in arrays:
        i=np.squeeze(i,axis=1)
        i=np.transpose(i,(0,2,1))
        output.append(i)
    return output

In [None]:
import torch
import numpy as np
import h5py
import gc
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
import torch.utils.data as utils

# 参数
matpath = '../Testing_Set/J0037_tracks.mat'
label_path = '../Testing_Set/J0037_class_label.mat'  # 你的标签文件
classnum = 15  # 类别数
ROI_EMBEDDING_DIM = 32

def loadmat(filename):
    """ 读取 MATLAB v7.3 .mat 文件 """
    with h5py.File(filename, 'r') as data:
        if 'Whole_tracks' not in data:
            raise KeyError("❌ 错误: 'Whole_tracks' 变量不存在！")
        
        whole_tracks = data['Whole_tracks']
        if 'count' not in whole_tracks or 'data' not in whole_tracks:
            raise KeyError(f"❌ 错误: 'Whole_tracks' 结构不完整！包含: {list(whole_tracks.keys())}")

        # 读取 count
        count = int(whole_tracks['count'][()].item())
        track = [np.transpose(data[whole_tracks['data'][i].item()][:]).astype(np.float32) for i in range(count)]
    
    return {'tracks': {'count': count, 'data': track}}

def load_labels(label_path):
    """ 读取标签 .mat 文件 """
    with h5py.File(label_path, 'r') as data:
        if 'class_label' not in data:
            raise KeyError("❌ 错误: 'class_label' 变量不存在！")
        
        class_label = data['class_label'][()]
        
        if isinstance(class_label, np.ndarray):
            if class_label.size == 1:  
                class_label = class_label.item()
            else:  
                class_label = np.array(class_label)
        else:
            class_label = int(class_label)

        print(f"✅ 成功解析 class_label, 形状: {class_label.shape}")
        return class_label


""" 测试模型 """
args_test_batch_size = 10000
NCLASS = int(classnum)

print(f"📌 处理数据: {matpath}")
mat = loadmat(matpath)
X_test = mat['tracks']['data']
X_test = np.asarray(X_test).astype(np.float32)
X_test_original = np.transpose(X_test, (0, 2, 1))  # 维度转换


📌 处理数据: ../Testing_Set/J0037_tracks.mat


In [4]:
def udflip(X_nparray, y_nparray, shuffle=True):

    if X_nparray.shape[2] == 4:
        if np.std(X_nparray[:, 0, :]) > np.std(X_nparray[:, -1, :]):
            print("Detected special info in first column, swapping...")
            X_nparray = np.concatenate((X_nparray[:, 1:, :], X_nparray[:, 0:1, :]), axis=1)
    
    X_flipped = np.flip(X_nparray, axis=2)  
    y_nparray = y_nparray.flatten()
    X_aug = np.vstack((X_nparray, X_flipped))
    y_aug = np.hstack((y_nparray, y_nparray))  

    if shuffle:
        shuffle_idx = np.random.permutation(X_aug.shape[0])
        return X_aug[shuffle_idx], y_aug[shuffle_idx]
    else:
        return X_aug, y_aug
# 读取标签
y_test = load_labels(label_path)
y_test_list = y_test
print(X_test_original.shape)
print(y_test.shape)
X_test, y_test = udflip(X_test_original,y_test,shuffle=False)
print(X_test.shape)
print(y_test.shape)
y_test = torch.from_numpy(y_test.astype(np.int64))  # 确保是整数类型
X_test = torch.from_numpy(X_test)

kwargs = {'num_workers': 1, 'pin_memory': True}
tst_set = utils.TensorDataset(X_test, y_test)
tst_loader = utils.DataLoader(tst_set, batch_size=args_test_batch_size, shuffle=False, **kwargs)

✅ 成功解析 class_label, 形状: (234541, 1)
(234541, 4, 100)
(234541, 1)
(469082, 4, 100)
(469082,)


In [5]:
len(y_test)

469082

In [6]:
X_test.shape, y_test.shape

(torch.Size([469082, 4, 100]), torch.Size([469082]))

In [7]:
X_test[0].shape

torch.Size([4, 100])

In [8]:
def aug_at_test(probs,mode='max'):
    assert(len(probs)>0)
    if(mode=='max'):
        all_probs=np.vstack(probs)
        print(all_probs.shape)
        max_probs=np.amax(all_probs,axis=1).reshape((2,-1))#row 0: prob for first half, row 1: prob for flipped half
        max_idx=np.argmax(max_probs,axis=0)#should be 0/1
        test_sample_count=all_probs.shape[0]/2
        
        class_pred=np.argmax(all_probs,axis=1)
        final_pred=list()
        for i in range(max_idx.shape[0]):
            final_pred.append(class_pred[int(i+test_sample_count*max_idx[i])])#if 0, first half
        return final_pred
    if(mode=='mean'):
        all_probs=np.exp(np.vstack(probs))
        test_sample_count=int(all_probs.shape[0]/2)
        final_probs=all_probs[0:test_sample_count]+all_probs[test_sample_count:]
        final_pred=np.argmax(final_probs,axis=1)
        return final_pred.tolist()

# import numpy as np

# def aug_at_test(probs, mode='max'):
#     """
#     适用于 **无数据增强** 的版本：
#     - 直接选择 `argmax` 作为最终预测
#     - `mode='max'` 或 `mode='mean'` 影响不大，因为没有翻转数据
    
#     参数：
#     - probs: 模型输出的 logits 列表，每个 batch 存储一次输出
    
#     返回：
#     - final_pred: 预测类别列表
#     """
#     assert len(probs) > 0, "probs 为空，无法计算预测结果"

#     # 合并所有 batch
#     all_probs = np.vstack(probs)  # 形状: (N, num_classes)

#     # 直接取最大概率类别作为预测类别
#     final_pred = np.argmax(all_probs, axis=1)

#     return final_pred.tolist()


In [13]:
# 数据加载
import os
import sys
# del os.environ['MKL_NUM_THREADS'] # error corrected by MH 10/12/2022 (add these three lines)
from Embedding_layer import ROIFeatureExtractor
import torch
from torch.autograd import Variable
import torch.utils.data as utils
import numpy as np
import gc
import sys
import scipy.io as spio
import RESNET152_ATT_naive
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
import torch.nn as nn
from Util import focalLoss, preprocess_fiber_input
from clustering_layer_v2 import ClusterlingLayer
from klDiv import KLDivLoss

modelpath = 'save_small/focal_loss_and_cluster_loss_c_10.0_FE.model'
# 加载模型
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
print(f'device: {device}')
ROI_EMBEDDING_DIM = 32
NUM_ROI_CLASSES = 726 + 1
HIDDEN_DIM = 64
model=RESNET152_ATT_naive.resnet18(num_classes=NCLASS, input_ch=3+ROI_EMBEDDING_DIM)
# init ROI Embedding layer
roi_embedding_layer = nn.Embedding(NUM_ROI_CLASSES, ROI_EMBEDDING_DIM).to(device)
# init FE
roi_extractor = ROIFeatureExtractor(roi_embedding_layer, ROI_EMBEDDING_DIM, hidden_dim=HIDDEN_DIM).to(device)
roi_extractor.to(device)
model.to(device)

# 2️⃣ 加载权重
state_dict = torch.load(modelpath, map_location=device)
model.load_state_dict(state_dict)

model.eval()
log_testing_total_loss = 0.0
log_focal_loss = 0.0
log_centering_loss= 0.
log_clustering_loss = 0.0
probs = []
preds = []
labels = []
clustering_layer = ClusterlingLayer(embedding_dimension=512, num_clusters=NCLASS, alpha=1.0)
kl_loss = KLDivLoss(NCLASS, loss_weight=2.0, temperature=2)
kl_loss.to(device)
clustering_layer.to(device)
global global_cluster_rois  # Ensure global access to cluster anatomical profiles
loss_nll = nn.NLLLoss(size_average=True) # log-softmax applied in the network
with torch.no_grad():
    for data, target in tst_loader:
        labels += target.cpu().numpy().tolist()
        # if args.cuda:
        data, target = data.to(device), target.to(device)
        # target = target.squeeze(1)
        data, target = Variable(data), Variable(target)
        # print(data.shape)
        # print(target.shape)
        data_processed = preprocess_fiber_input(data, roi_extractor=roi_extractor, device=device, net_type='FE')
        output, embed, _, _, _, _, _, _, _, _, _ = model(data_processed)

        # Compute focal loss
        floss = focalLoss(output, target, loss_nll=loss_nll)
        total_loss = floss
        log_focal_loss += floss.item()
        # Compute center loss if enabled

        # Compute clustering loss if enabled
        clustering_out, x_dis = clustering_layer(embed)

        # Get predicted cluster labels
        tar_dist = ClusterlingLayer.create_soft_labels(target, NCLASS, temperature=2).to(target.device)
        loss_clust = 10 * kl_loss.kl_div_cluster(torch.log(clustering_out), tar_dist) / 1024

        total_loss += loss_clust
        log_clustering_loss += loss_clust.item()

        # Accumulate total test loss
        log_testing_total_loss += total_loss.item()
        probs.append(output.data.cpu().numpy())

# Compute final predictions using test-time augmentation
preds = aug_at_test(probs, mode='max')
num_batch = len(tst_loader) / 1024

# Compute evaluation metrics
conf_mat = confusion_matrix(y_test_list, preds)
precision, recall, f1, _ = precision_recall_fscore_support(y_test_list, preds, average='macro')

avg_testing_loss = log_testing_total_loss / num_batch
avg_clustering_loss = log_clustering_loss / num_batch
print('\tCenter loss: {:.4f}'.format(log_centering_loss / num_batch))
print('\tfocal loss: {:.4f}'.format(log_focal_loss/num_batch))
print(f'Test set avg loss: {avg_testing_loss:.4f}')

print(f'\tClustering loss: {avg_clustering_loss:.4f}')
print('Precision, Recall, macro F1:', precision, recall, f1)

device: cuda:1
using ROI with emb: 32


  return F.log_softmax(x), embed, x_att, x, out1, out2, out3, final_feat, out1_feat, out2_feat, out3_feat


(469082, 15)
	Center loss: 0.0000
	focal loss: 0.0065
Test set avg loss: 1185.4198
	Clustering loss: 1185.4134
Precision, Recall, macro F1: 0.9180087466607116 0.8928001436431507 0.9028530012430607


In [14]:
print("\n".join(sys.argv))
# checkArgc(sys.argv)
main()

/home/bohan/.conda/envs/deterministic-a-bridge/lib/python3.9/site-packages/ipykernel_launcher.py
--f=/home/bohan/.local/share/jupyter/runtime/kernel-v3954b1db88fadc1a3a24ceac916b91b22e4d04f93.json


NameError: name 'main' is not defined

In [None]:
import h5py

filename = "../Testing_Set/J0037_tracks.mat"

# 读取 MAT v7.3 文件
with h5py.File(filename, 'r') as f:
    print("MAT 文件中的变量:", list(f.keys()))  # 输出所有变量名称


In [None]:
with h5py.File(filename, 'r') as f:
    if 'tracks' in f:
        print("tracks 内部结构:", list(f['tracks'].keys()))
    else:
        print("❌ 错误: 'tracks' 不存在！")


In [None]:
with h5py.File(filename, 'r') as f:
    print("MAT 文件中的变量:", list(f.keys()))  # 顶层变量

    # 检查 Whole_tracks 里面的内容
    if 'Whole_tracks' in f:
        print("\n🔍 Whole_tracks 结构:")
        try:
            print("    子变量:", list(f['Whole_tracks'].keys()))  # 打印 Whole_tracks 内部结构
        except AttributeError:
            print("    Whole_tracks 不是结构体，可能是数组或标量")

In [None]:
import h5py
import numpy as np

def loadmat(filename):
    '''
    读取 MATLAB v7.3 `.mat` 文件（Whole_tracks 作为 tracks）
    '''
    output = dict()
    
    # 打开 HDF5 MAT 文件
    with h5py.File(filename, 'r') as data:
        # 读取 Whole_tracks 变量
        if 'Whole_tracks' not in data:
            raise KeyError("❌ 错误: 'Whole_tracks' 变量不存在！")

        whole_tracks = data['Whole_tracks']  # 结构体 Whole_tracks

        # 确保它有 `count` 和 `data`
        if 'count' not in whole_tracks or 'data' not in whole_tracks:
            raise KeyError(f"❌ 错误: 'Whole_tracks' 结构不完整！包含: {list(whole_tracks.keys())}")

        # 读取 count（可能是字符编码格式，需要解析）
        count = whole_tracks['count'][()]  
        print("🔍 Whole_tracks['count'] 数据:", count)
        print("🔍 数据类型:", type(count))

        # 直接转换成整数
        total_count = int(count.item())
        # 读取 Whole_tracks['data']
        track = []
        for i in range(total_count):
            data_ref = whole_tracks['data'][i].item()
            track.append(np.transpose(data[data_ref][:]).astype(np.float32))

        # 组织输出
        output['tracks'] = {
            'count': total_count,
            'data': track
        }
    
    return output

# 运行测试
filename = "../Testing_Set/J0037_tracks.mat"
tracks_data = loadmat(filename)
print("✅ 成功读取 Whole_tracks 数据！")
print(f"轨迹数: {tracks_data['tracks']['count']}")


In [None]:
X_test=tracks_data['tracks']['data']
X_test=rescale(X_test,int(tracks_data['tracks']['count']))
#X_test_ud=np.asarray(udflip(X_test,int(mat['tracks']['count']))).astype(np.float32)
X_test=np.asarray(X_test).astype(np.float32)
#X_test=np.vstack((X_test,X_test_ud))
#X_test=X_test.reshape((X_test.shape[0],1,X_test.shape[1],X_test.shape[2]))
#X_test=datato3d(X_test)[0]
X_test=np.transpose(X_test,(0,2,1))

In [None]:
X_test[0].shape

In [None]:
import h5py
import numpy as np

def loadmat(filename):
    '''
    读取 MATLAB v7.3 `.mat` 文件（Whole_tracks 作为 tracks）
    '''
    output = dict()
    
    # 打开 HDF5 MAT 文件
    with h5py.File(filename, 'r') as data:
        # 读取 Whole_tracks 变量
        if 'Whole_tracks' not in data:
            raise KeyError("❌ 错误: 'Whole_tracks' 变量不存在！")

        whole_tracks = data['Whole_tracks']  # 结构体 Whole_tracks

        # 确保它有 `count` 和 `data`
        if 'count' not in whole_tracks or 'data' not in whole_tracks:
            raise KeyError(f"❌ 错误: 'Whole_tracks' 结构不完整！包含: {list(whole_tracks.keys())}")

        # 读取 count（可能是字符编码格式，需要解析）
        count = whole_tracks['count'][()]  
        print("🔍 Whole_tracks['count'] 数据:", count)
        print("🔍 数据类型:", type(count))

        # 直接转换成整数
        total_count = int(count.item())
        # 读取 Whole_tracks['data']
        track = []
        for i in range(total_count):
            data_ref = whole_tracks['data'][i].item()
            track.append(np.transpose(data[data_ref][:]).astype(np.float32))

        # 组织输出
        output['tracks'] = {
            'count': total_count,
            'data': track
        }
    
    return output

# 运行测试
filename = "../Testing_Set/J0037_class_label.mat"
tracks_data = loadmat(filename)
print("✅ 成功读取 Whole_tracks 数据！")
print(f"轨迹数: {tracks_data['tracks']['count']}")


In [None]:
import h5py

filename = "../Testing_Set/J0037_class_label.mat"

with h5py.File(filename, 'r') as f:
    print("MAT 文件变量:", list(f.keys()))  # 列出顶层变量
    
    for key in f.keys():
        print(f"\n🔍 变量 '{key}' 结构:")
        try:
            print("    子变量:", list(f[key].keys()))  # 如果是 group，打印内部结构
        except AttributeError:
            print("    不是结构体，可能是数组或标量")


In [None]:
import h5py
import numpy as np

def load_label_mat(filename):
    """
    读取 MATLAB v7.3 (.mat) 文件中的 `class_label`
    """
    with h5py.File(filename, 'r') as data:
        if 'class_label' not in data:
            raise KeyError("❌ 错误: 'class_label' 变量不存在！")

        # 读取 class_label
        class_label = data['class_label'][()]

        # 解析数据
        if isinstance(class_label, np.ndarray):
            if class_label.size == 1:  # 只有一个值
                class_label = class_label.item()
            else:  # 多个值，转换为 NumPy 数组
                class_label = np.array(class_label)
        else:
            class_label = int(class_label)  # 可能是单个数值

        print(f"✅ 成功解析 class_label: {class_label}")
        return class_label

# 测试
filename = "../Testing_Set/J0037_class_label.mat"
labels = load_label_mat(filename)
print("📌 解析出的 labels:", labels)


In [55]:
import torch
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

def evaluate_model(model, test_loader, device='cuda'):
    """
    通用测试函数（无损失函数计算）
    - 适用于分类任务
    - 计算预测结果、混淆矩阵、Precision、Recall、F1

    参数：
    - model: 训练好的 PyTorch 模型
    - test_loader: PyTorch DataLoader (测试集)
    - device: 'cuda' or 'cpu'

    返回：
    - conf_matrix: 混淆矩阵
    - precision, recall, f1: 分类指标
    """
    model.eval()  # 设置模型为评估模式
    preds = []
    labels = []
    
    with torch.no_grad():  # 不计算梯度，加速推理
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            # 获取模型输出
            output = model(data)

            # 分类任务：获取预测类别
            if output.dim() > 1:  # 确保 output 是 logits 形式
                pred = output.argmax(dim=1)  # 取最大概率的类别
                preds.extend(pred.cpu().numpy())
                labels.extend(target.cpu().numpy())

    # 计算分类指标
    if preds:
        conf_matrix = confusion_matrix(labels, preds)
        precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    else:
        conf_matrix, precision, recall, f1 = None, None, None, None

    # 输出结果
    print(f'📊 Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}')
    print(f'🔢 混淆矩阵:\n{conf_matrix}')

    return conf_matrix, precision, recall, f1
