In [1]:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from network.model_unet_a_2d import *
from loss_utils import *
from data_loader import *
from data_augmentation import *
from test_utils import model_predict, dataset_eval
from torchinfo import summary
import random
import torch.backends.cudnn as cudnn

In [2]:
from copy import deepcopy
import math
from data_augmentation import *


def obtain_cutmix_box(signal_size, p=0.5, length_min=0.02, length_max=0.4):
    # 初始化全 0 的 1D 掩码
    mask = np.zeros((signal_size, 1))
    
    # 以概率 p 决定是否应用 CutMix
    if random.random() > p:
        return mask

    # 随机选择裁剪长度，范围为 length_min * signal_size 到 length_max * signal_size
    cutmix_len = int(np.random.uniform(length_min, length_max) * signal_size)
    
    # 随机选择起始位置，确保裁剪区域在信号内
    start = np.random.randint(0, signal_size - cutmix_len + 1)
    
    # 将掩码中对应区域设为 1
    mask[start:start + cutmix_len, :] = 1

    return mask


class SemiDataset(Dataset):
    def __init__(self, x, y, mode, nsample=None):
        self.x = x
        self.y = y
        self.mode = mode
        self.ids = list(range(len(x)))
        if mode == 'train_l' and nsample is not None:
            self.ids *= math.ceil(nsample / len(self.ids))
            self.ids = self.ids[:nsample]

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, item):
        ecg = self.x[item]
        mask = self.y[item]

        if self.mode == 'val':
            ecg = zscore_normalize(ecg, axis=0)
            ecg = torch.from_numpy(ecg.astype(np.float32)).permute(1, 0).unsqueeze(1) # (channel=1, 1, length)
            mask = torch.from_numpy(mask.astype(np.float32)).permute(1, 0).unsqueeze(1) # (channel=4, 1, length)
            return ecg, mask

        ecg, mask = random_resize(ecg, mask, scale_range=(0.5, 2))
        if random.random() < 0.5:
            ecg, mask = np.flip(ecg, axis=0).copy(), np.flip(mask, axis=0).copy()

        if self.mode == 'train_l':
            ecg = zscore_normalize(ecg, axis=0)
            ecg = torch.from_numpy(ecg.astype(np.float32)).permute(1, 0).unsqueeze(1)
            mask = torch.from_numpy(mask.astype(np.float32)).permute(1, 0).unsqueeze(1)
            return ecg, mask

        ecg_w, ecg_s1, ecg_s2 = deepcopy(ecg), deepcopy(ecg), deepcopy(ecg)

        if random.random() < 0.8:
            ecg_s1 = ecg_s1 + baseline_wander_noise(ecg_s1[:,0], fs=500, snr=-10, freq=0.15)[:,np.newaxis]
        if random.random() < 0.5:
            ecg_s1 = ecg_s1 + additive_white_gaussian_noise(ecg_s1[:,0], snr=10)[:,np.newaxis]
        cutmix_box1 = obtain_cutmix_box(ecg_s1.shape[0], p=0.5)

        if random.random() < 0.8:
            ecg_s2 = ecg_s2 + baseline_wander_noise(ecg_s2[:,0], fs=500, snr=-10, freq=0.15)[:,np.newaxis]
        if random.random() < 0.5:
            ecg_s2 = ecg_s2 + additive_white_gaussian_noise(ecg_s2[:,0], snr=10)[:,np.newaxis]
        cutmix_box2 = obtain_cutmix_box(ecg_s2.shape[0], p=0.5)

        ecg_w = zscore_normalize(ecg_w, axis=0)
        ecg_s1 = zscore_normalize(ecg_s1, axis=0)
        ecg_s2 = zscore_normalize(ecg_s2, axis=0)

        ecg_w = torch.from_numpy(ecg_w.astype(np.float32)).permute(1, 0).unsqueeze(1)
        ecg_s1 = torch.from_numpy(ecg_s1.astype(np.float32)).permute(1, 0).unsqueeze(1)
        ecg_s2 = torch.from_numpy(ecg_s2.astype(np.float32)).permute(1, 0).unsqueeze(1)
        mask = torch.from_numpy(mask.astype(np.float32)).permute(1, 0).unsqueeze(1)
        cutmix_box1 = torch.from_numpy(cutmix_box1.astype(np.float32)).permute(1, 0).unsqueeze(1)
        cutmix_box2 = torch.from_numpy(cutmix_box2.astype(np.float32)).permute(1, 0).unsqueeze(1)
        
        ignore_mask = torch.zeros((mask.shape[1], mask.shape[2]), dtype=torch.float32)
        return ecg_w, ecg_s1, ecg_s2, ignore_mask, cutmix_box1, cutmix_box2
    

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0
        self.val = 0.0
        self.avg = 0.0

    def update(self, val, num=1):
        if self.length > 0:
            # currently assert num==1 to avoid bad usage, refine when there are some explict requirements
            assert num == 1
            self.history.append(val)
            if len(self.history) > self.length:
                del self.history[0]

            self.val = self.history[-1]
            self.avg = np.mean(self.history)
        else:
            self.val = val
            self.sum += val * num
            self.count += num
            self.avg = self.sum / self.count


def train_Unimatch(model, trainset_u, trainset_l, valset, model_path, deep_supervision=False):
    batch_size = 32
    ini_lr = 1e-1
    lr = copy.deepcopy(ini_lr)
    epochs = 80
    # epochs = 1
    conf_thresh = 0.95
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    criterion_l = nn.CrossEntropyLoss()
    criterion_u = nn.CrossEntropyLoss(reduction='none')

    trainloader_u = DataLoader(trainset_u, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    trainloader_l = DataLoader(trainset_l, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    valloader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    total_iters = len(trainloader_u) * epochs
    best_val_loss = np.inf
    with tqdm(total=epochs, desc='Training Progress', unit='epoch') as pbar:
        for epoch in range(epochs):
            total_loss = AverageMeter()
            total_loss_x = AverageMeter()
            total_loss_s = AverageMeter()
            total_loss_w_fp = AverageMeter()
            total_mask_ratio = AverageMeter()

            loader = zip(trainloader_l, trainloader_u, trainloader_u)
            for i, ((ecg_x, mask_x),
                    (ecg_u_w, ecg_u_s1, ecg_u_s2, ignore_mask, cutmix_box1, cutmix_box2),
                    (ecg_u_w_mix, ecg_u_s1_mix, ecg_u_s2_mix, ignore_mask_mix, _, _)) in enumerate(loader):
                
                ecg_x, mask_x = ecg_x.cuda(), mask_x.cuda()
                ecg_u_w = ecg_u_w.cuda()
                ecg_u_s1, ecg_u_s2, ignore_mask = ecg_u_s1.cuda(), ecg_u_s2.cuda(), ignore_mask.cuda()
                cutmix_box1, cutmix_box2 = cutmix_box1.cuda(), cutmix_box2.cuda()
                ecg_u_w_mix = ecg_u_w_mix.cuda()
                ecg_u_s1_mix, ecg_u_s2_mix = ecg_u_s1_mix.cuda(), ecg_u_s2_mix.cuda()
                ignore_mask_mix = ignore_mask_mix.cuda()

                with torch.no_grad():
                    model.eval()

                    pred_u_w_mix = model(ecg_u_w_mix).detach()
                    conf_u_w_mix = pred_u_w_mix.max(dim=1)[0]
                    mask_u_w_mix = pred_u_w_mix.argmax(dim=1)

                ecg_u_s1[cutmix_box1 == 1] = \
                    ecg_u_s1_mix[cutmix_box1 == 1]
                ecg_u_s2[cutmix_box2 == 1] = \
                    ecg_u_s2_mix[cutmix_box2 == 1]

                model.train()

                num_lb, num_ulb = ecg_x.shape[0], ecg_u_w.shape[0]

                preds, preds_fp = model(torch.cat((ecg_x, ecg_u_w)), True)
                pred_x, pred_u_w = preds.split([num_lb, num_ulb])
                pred_u_w_fp = preds_fp[num_lb:]

                pred_u_s1, pred_u_s2 = model(torch.cat((ecg_u_s1, ecg_u_s2))).chunk(2)

                pred_u_w = pred_u_w.detach()
                conf_u_w = pred_u_w.max(dim=1)[0]
                mask_u_w = pred_u_w.argmax(dim=1)

                mask_u_w_cutmixed1, conf_u_w_cutmixed1, ignore_mask_cutmixed1 = \
                    mask_u_w.clone(), conf_u_w.clone(), ignore_mask.clone()
                mask_u_w_cutmixed2, conf_u_w_cutmixed2, ignore_mask_cutmixed2 = \
                    mask_u_w.clone(), conf_u_w.clone(), ignore_mask.clone()

                mask_u_w_cutmixed1[cutmix_box1.squeeze(1) == 1] = mask_u_w_mix[cutmix_box1.squeeze(1) == 1]
                conf_u_w_cutmixed1[cutmix_box1.squeeze(1) == 1] = conf_u_w_mix[cutmix_box1.squeeze(1) == 1]
                ignore_mask_cutmixed1[cutmix_box1.squeeze(1) == 1] = ignore_mask_mix[cutmix_box1.squeeze(1) == 1]

                mask_u_w_cutmixed2[cutmix_box2.squeeze(1) == 1] = mask_u_w_mix[cutmix_box2.squeeze(1) == 1]
                conf_u_w_cutmixed2[cutmix_box2.squeeze(1) == 1] = conf_u_w_mix[cutmix_box2.squeeze(1) == 1]
                ignore_mask_cutmixed2[cutmix_box2.squeeze(1) == 1] = ignore_mask_mix[cutmix_box2.squeeze(1) == 1]
                
                if deep_supervision:
                    pred_xs = model(ecg_x, full_output=True)[0:4]
                    loss_x = sum([criterion_l(pred, mask_x) for pred in pred_xs])
                else:
                    loss_x = criterion_l(pred_x, mask_x)

                loss_u_s1 = criterion_u(pred_u_s1, mask_u_w_cutmixed1)
                loss_u_s1 = loss_u_s1 * ((conf_u_w_cutmixed1 >= conf_thresh) & (ignore_mask_cutmixed1 != 255))
                loss_u_s1 = loss_u_s1.sum() / (ignore_mask_cutmixed1 != 255).sum().item()

                loss_u_s2 = criterion_u(pred_u_s2, mask_u_w_cutmixed2)
                loss_u_s2 = loss_u_s2 * ((conf_u_w_cutmixed2 >= conf_thresh) & (ignore_mask_cutmixed2 != 255))
                loss_u_s2 = loss_u_s2.sum() / (ignore_mask_cutmixed2 != 255).sum().item()

                loss_u_w_fp = criterion_u(pred_u_w_fp, mask_u_w)
                loss_u_w_fp = loss_u_w_fp * ((conf_u_w >= conf_thresh) & (ignore_mask != 255))
                loss_u_w_fp = loss_u_w_fp.sum() / (ignore_mask != 255).sum().item()

                loss = (loss_x + loss_u_s1 * 0.25 + loss_u_s2 * 0.25 + loss_u_w_fp * 0.5) / 2.0

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss.update(loss.item())
                total_loss_x.update(loss_x.item())
                total_loss_s.update((loss_u_s1.item() + loss_u_s2.item()) / 2.0)
                total_loss_w_fp.update(loss_u_w_fp.item())
                
                mask_ratio = ((conf_u_w >= conf_thresh) & (ignore_mask != 255)).sum().item() / \
                    (ignore_mask != 255).sum()
                total_mask_ratio.update(mask_ratio.item())

                iters = epoch * len(trainloader_u) + i
                lr = ini_lr * (1 - iters / total_iters) ** 0.9
                optimizer.param_groups[0]["lr"] = lr

            # validation
            model.eval()
            val_loss = 0.0
            for ecg, mask in valloader:
                ecg, mask = ecg.cuda(), mask.cuda()
                pred = model(ecg)
                loss = criterion_l(pred, mask)
                val_loss += loss.item()
            val_loss /= len(valloader)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), f'{model_path}.best.pth')

            pbar.set_postfix_str(f'loss: {total_loss.avg:.4f}, val_loss: {val_loss:.4f}, mask_ratio: {total_mask_ratio.avg:.4f}')
            pbar.update(1)

        torch.save(model.state_dict(), f'{model_path}.final.pth')

In [3]:
th_delineation = 150
gpu = 0
aug = 2
deep_supervision = 1
torch.cuda.set_device(gpu)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Parameters
model_save_path = f'./checkpoints/unet_a_ds{int(deep_supervision)}_UniMatch-V1.cross'
metrics_save_path = f'./metrics/unet_a_ds{int(deep_supervision)}_UniMatch-V1.cross'
records_list = [5,10,20,50,160]
val_ratio = 0.2
df_list = []

for num_labeled in records_list:
    ## Train
    print(f"Number of labeled data: {num_labeled}")
    print("Training Stage")
    for fold in range(5):
        if os.path.exists(f"{model_save_path}.num_labeled_{num_labeled}_fold_{fold}.final.pth"):
            continue
        # Set random seed for reproducibility
        seed = 42
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        x_train_l, y_train_l, _, _, _ = raw_data_load_ludb(40, num_labeled, fold, crop=[1250, 3750])
        x_train_u, y_train_u, _, _, _ = raw_data_load_rdb(400, 1999, fold, crop=[1250, 3750])

        num_val =  np.round(x_train_l.shape[0] * val_ratio).astype(int)
        x_val, y_val = x_train_l[:num_val], y_train_l[:num_val]
        x_train_l, y_train_l = x_train_l[num_val:], y_train_l[num_val:]

        print(f"Fold {fold+1}/{5}: Train labeled: {x_train_l.shape[0]}, Train unlabeled: {x_train_u.shape[0]}, Val: {x_val.shape[0]}")
        
        model = UNet1D_A(length=2500, base_channels=16, kernel_size=9, dropout='channels', droprate=.2, num_classes=2).to('cuda')
        ini_ds = deep_supervision
        ini_aug = 2
        model_load_path = f"./checkpoints/unet_a_ds{ini_ds}.num_labeled_{num_labeled}_aug_{ini_aug}.fold_{fold}.epoch_20.pth"
        model.load_state_dict(torch.load(model_load_path))
        trainset_u = SemiDataset(x_train_u, y_train_u, 'train_u')
        trainset_l = SemiDataset(x_train_l, y_train_l, 'train_l')
        valset = SemiDataset(x_val, y_val, 'val')

        model_path = f"{model_save_path}.num_labeled_{num_labeled}_fold_{fold}"
        train_Unimatch(model, trainset_u, trainset_l, valset, model_path, deep_supervision=deep_supervision)

    ## Test
    print("Test Stage")
    data = []
    label = []
    preds = []
    seg_metrics_macro = []
    deli_metrics_macro = []
    
    for fold in range(5):
        model = UNet1D_A(length=2500, base_channels=16, kernel_size=9, dropout='channels', droprate=.2, num_classes=2).to(device)
        model_load_path = f"{model_save_path}.num_labeled_{num_labeled}_fold_{fold}.best.pth"
        model.load_state_dict(torch.load(model_load_path))

        x_train, y_train, _, x_test, y_test = raw_data_load_ludb(40, 160, fold, crop=[0, 5000])
        test_dataset = ECGDataset(x_test, y_test, transform=base_transforms())

        pred = model_predict(model, model_load_path, test_dataset, device, multi_lead_correction=False)
        
        flag_ludb = np.load('./dataset/ludb/flag.npy')
        index_shuffled_5fold = np.load('./dataset/ludb/ludb_index_shuffled_5fold_250113.npy')
        index_shuffled = index_shuffled_5fold[:,fold]
        index_shuffled_lead = []
        for i in np.array(index_shuffled):
            index_shuffled_lead.extend([k for k in range(12*i,12*i+12,1)])
        num_test = 40
        flag_test= flag_ludb[index_shuffled_lead[0:num_test*12]]
        dataset = (x_test, y_test, np.zeros((x_test.shape[0],)), flag_test)
        _, _, seg_metrics, deli_metrics = dataset_eval(dataset, pred, th_delineation=th_delineation, verbose=0)

        data.append(x_test)
        label.append(y_test)
        preds.append(pred)
        seg_metrics_macro.append(seg_metrics)
        deli_metrics_macro.append(deli_metrics)

    data = np.concatenate(data, axis=0)
    label = np.concatenate(label, axis=0)
    preds = np.concatenate(preds, axis=0)

    def summarize_total_rows(dfs):
        """
        Extracts 'Total' rows from DataFrames, calculates summary statistics,
        and returns a new DataFrame.

        Args:
        dfs: A list of pandas DataFrames with identical structure.

        Returns:
            A pandas DataFrame containing the mean, std, max, and min
            of each column of the 'Total' rows from all input DataFrames.
        """
        total_rows = [df[df['type'] == 'Total'].iloc[0] for df in dfs]
        total_df = pd.DataFrame(total_rows)

        # Get original headers, and remove 'type'
        original_headers = total_df.columns.tolist()
        original_headers.remove('type')


        # Calculate summary stats for each column (excluding type)
        summary_data = {
            'mean': total_df[original_headers].mean().to_list(),
            'std': total_df[original_headers].std().to_list(),
            'max': total_df[original_headers].max().to_list(),
            'min': total_df[original_headers].min().to_list()
        }
        # Create the summary DataFrame
        summary_df = pd.DataFrame(summary_data, index = original_headers)
        return summary_df
    
    # Macro average metrics of 5 folds
    seg_metrics_macro = summarize_total_rows(seg_metrics_macro)
    deli_metrics_macro = summarize_total_rows(deli_metrics_macro) 
    filtered_df = deli_metrics_macro[deli_metrics_macro.index.str.contains('f1')]
    merged_df_macro = pd.concat([seg_metrics_macro, filtered_df], axis = 0)                

    # Micro average metrics of 5 folds
    dataset = (data, label, np.zeros((data.shape[0],)), np.zeros((data.shape[0],)))
    _, _, seg_metrics_micro, deli_metrics_micro = dataset_eval(dataset, preds, th_delineation=th_delineation, verbose=0)
    # Filter df2 to include rows where column name contain 'f1'
    filtered_df = deli_metrics_micro[['type'] + [col for col in deli_metrics_micro.columns if 'f1' in col]]
    merged_df_micro = pd.merge(seg_metrics_micro, filtered_df, on='type', how='outer')
    micro_row = merged_df_micro[merged_df_micro['type'] == 'Total'].iloc[0]
    # Remove 'type' and convert to series
    micro_row_values = micro_row.drop('type')
    merged_df = copy.deepcopy(merged_df_macro)
    merged_df['micro'] = micro_row_values
    
    df_list.append(merged_df)
    # Final results
    print(merged_df)


# Save results
# 拼接数据，将总标题作为列上方的“标题行”
concat_frames = []
for i, df in enumerate(df_list):
    # 插入标题行
    df_with_title = df.copy()
    df_with_title.columns = pd.MultiIndex.from_tuples([(str(records_list[i]), col) for col in df.columns])
    concat_frames.append(df_with_title)

# 按列拼接，并保留行标题
result = pd.concat(concat_frames, axis=1)

# 写入 Excel
result.to_excel(f"{metrics_save_path}.xlsx")

Number of labeled data: 5
Training Stage
Test Stage
                mean       std       max       min     micro
iou_p       0.625751  0.078055  0.708601  0.527379  0.624759
iou_qrs     0.842977  0.012109  0.857707  0.832309  0.840664
iou_t       0.749867  0.036341  0.798074  0.711801  0.750223
miou        0.739531  0.040579  0.788127  0.697687  0.738549
acc         0.891814  0.011076  0.907942  0.882268  0.890136
ave_f1      0.925151  0.024597  0.954095  0.901746  0.924202
f1_p_on     0.878912  0.039252  0.932827  0.832110  0.875951
f1_p_end    0.879126  0.039217  0.933739  0.832110  0.876129
f1_qrs_on   0.987336  0.007244  0.996551  0.979779  0.987479
f1_qrs_end  0.986742  0.008285  0.996551  0.977523   0.98653
f1_t_on     0.909607  0.033671  0.946257  0.861512  0.909732
f1_t_end    0.909182  0.031719  0.943831  0.864204  0.909391
Number of labeled data: 10
Training Stage
Test Stage
                mean       std       max       min     micro
iou_p       0.722013  0.024523  0.758033 

Training Progress: 100%|██████████| 80/80 [1:41:27<00:00, 76.10s/epoch, loss: 1.9688, val_loss: 0.7973, mask_ratio: 0.9628]


Fold 2/5: Train labeled: 1536, Train unlabeled: 23970, Val: 384


Training Progress: 100%|██████████| 80/80 [1:41:10<00:00, 75.88s/epoch, loss: 1.9669, val_loss: 0.8087, mask_ratio: 0.9643]


Fold 3/5: Train labeled: 1536, Train unlabeled: 23970, Val: 384


Training Progress: 100%|██████████| 80/80 [1:41:15<00:00, 75.94s/epoch, loss: 1.9715, val_loss: 0.7997, mask_ratio: 0.9631]


Fold 4/5: Train labeled: 1536, Train unlabeled: 23970, Val: 384


Training Progress: 100%|██████████| 80/80 [1:41:06<00:00, 75.83s/epoch, loss: 1.9682, val_loss: 0.7964, mask_ratio: 0.9642]


Fold 5/5: Train labeled: 1536, Train unlabeled: 23970, Val: 384


Training Progress: 100%|██████████| 80/80 [1:41:08<00:00, 75.86s/epoch, loss: 1.9666, val_loss: 0.8012, mask_ratio: 0.9636]


Test Stage
                mean       std       max       min     micro
iou_p       0.830528  0.023219  0.857632  0.793596  0.826733
iou_qrs     0.897515  0.009313  0.911197  0.886583   0.89626
iou_t       0.863331  0.018941  0.893228  0.845948  0.863583
miou        0.863791  0.015953  0.887352  0.843737  0.862192
acc         0.939568  0.006736  0.950891  0.933983  0.938974
ave_f1      0.984933  0.005411  0.992721  0.978894  0.984175
f1_p_on     0.972682  0.009268  0.986264  0.960600  0.970039
f1_p_end    0.972927  0.009075  0.985962  0.960600  0.970277
f1_qrs_on   0.996359  0.001842  0.998219  0.993599  0.996369
f1_qrs_end  0.996357  0.001912  0.998219  0.993599  0.995916
f1_t_on     0.985828  0.007594  0.994078  0.976924  0.986427
f1_t_end    0.985442  0.007429  0.994851  0.976924  0.986023
