In [1]:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from network.model_unet_a_2d import *
from loss_utils import *
from data_loader import *
from data_augmentation import *
from test_utils import model_predict, dataset_eval
from torchinfo import summary
import random
import torch.backends.cudnn as cudnn
from torch.utils.data.sampler import Sampler
import itertools


batch_size = 24
labeled_bs = 12
top_num = 4
patch_size_h = 1
patch_size_w = 250
h_size = 1
w_size = 10

In [2]:
from einops import rearrange

def ABD_R(outputs1_max, outputs2_max, volume_batch, volume_batch_strong, outputs1_unlabel, outputs2_unlabel, args):
    # ABD-R Bidirectional Displacement Patch
    patches_1 = rearrange(outputs1_max[labeled_bs:], 'b (h p1) (w p2)->b (h w) (p1 p2)', p1=patch_size_h, p2=patch_size_w)
    patches_2 = rearrange(outputs2_max[labeled_bs:], 'b (h p1) (w p2)->b (h w) (p1 p2)', p1=patch_size_h, p2=patch_size_w)
    ecg_patch_1 = rearrange(volume_batch.squeeze(1)[labeled_bs:], 'b  (h p1) (w p2) -> b (h w)(p1 p2) ', p1=patch_size_h, p2=patch_size_w)  # torch.Size([8, 224, 224])
    ecg_patch_2 = rearrange(volume_batch_strong.squeeze(1)[labeled_bs:], 'b  (h p1) (w p2) -> b (h w)(p1 p2) ', p1=patch_size_h, p2=patch_size_w)
    patches_mean_1 = torch.mean(patches_1.detach(), dim=2)  # torch.Size([8, 16])
    patches_mean_2 = torch.mean(patches_2.detach(), dim=2)

    patches_outputs_1 = rearrange(outputs1_unlabel, 'b c (h p1) (w p2)->b c (h w) (p1 p2)', p1=patch_size_h, p2=patch_size_w)
    patches_outputs_2 = rearrange(outputs2_unlabel, 'b c (h p1) (w p2)->b c (h w) (p1 p2)', p1=patch_size_h, p2=patch_size_w)
    patches_mean_outputs_1 = torch.mean(patches_outputs_1.detach(), dim=3).permute(0, 2, 1)  # torch.Size([8, 16, 4])
    patches_mean_outputs_2 = torch.mean(patches_outputs_2.detach(), dim=3).permute(0, 2, 1)  # torch.Size([8, 16, 4])

    patches_mean_1_top4_values, patches_mean_1_top4_indices = patches_mean_1.topk(top_num, dim=1)  # torch.Size([8, 4])
    patches_mean_2_top4_values, patches_mean_2_top4_indices = patches_mean_2.topk(top_num, dim=1)  # torch.Size([8, 4])
    for i in range(labeled_bs):
        kl_similarities_1 = torch.empty(top_num)
        kl_similarities_2 = torch.empty(top_num)
        b = torch.argmin(patches_mean_1[i].detach(), dim=0)
        d = torch.argmin(patches_mean_2[i].detach(), dim=0)
        patches_mean_outputs_min_1 = patches_mean_outputs_1[i, b, :]  # torch.Size([4])
        patches_mean_outputs_min_2 = patches_mean_outputs_2[i, d, :]  # torch.Size([4])
        patches_mean_outputs_top4_1 = patches_mean_outputs_1[i, patches_mean_1_top4_indices[i, :], :]  # torch.Size([4, 4])
        patches_mean_outputs_top4_2 = patches_mean_outputs_2[i, patches_mean_2_top4_indices[i, :], :]  # torch.Size([4, 4])

        for j in range(top_num):
            kl_similarities_1[j] = torch.nn.functional.kl_div(patches_mean_outputs_top4_1[j].softmax(dim=-1).log(), patches_mean_outputs_min_2.softmax(dim=-1), reduction='sum')
            kl_similarities_2[j] = torch.nn.functional.kl_div(patches_mean_outputs_top4_2[j].softmax(dim=-1).log(), patches_mean_outputs_min_1.softmax(dim=-1), reduction='sum')

        a = torch.argmin(kl_similarities_1.detach(), dim=0, keepdim=False)
        c = torch.argmin(kl_similarities_2.detach(), dim=0, keepdim=False)
        a_ori = patches_mean_1_top4_indices[i, a]
        c_ori = patches_mean_2_top4_indices[i, c]

        max_patch_1 = ecg_patch_2[i][c_ori]  
        ecg_patch_1[i][b] = max_patch_1  
        max_patch_2 = ecg_patch_1[i][a_ori]
        ecg_patch_2[i][d] = max_patch_2 

    ecg_patch = torch.cat([ecg_patch_1, ecg_patch_2], dim=0)
    ecg_patch_last = rearrange(ecg_patch, 'b (h w)(p1 p2) -> b  (h p1) (w p2)', h=h_size, w=w_size,p1=patch_size_h, p2=patch_size_w) 
    return ecg_patch_last

def ABD_R_BCP(out_max_1, out_max_2, net_input_1, net_input_2, out_1, out_2):
    patches_1 = rearrange(out_max_1, 'b (h p1) (w p2)->b (h w) (p1 p2)', p1=patch_size_h, p2=patch_size_w)
    patches_2 = rearrange(out_max_2, 'b (h p1) (w p2)->b (h w) (p1 p2)', p1=patch_size_h, p2=patch_size_w)
    ecg_patch_1 = rearrange(net_input_1.squeeze(1), 'b  (h p1) (w p2) -> b (h w)(p1 p2) ',p1=patch_size_h, p2=patch_size_w)  # torch.Size([12, 224, 224])
    ecg_patch_2 = rearrange(net_input_2.squeeze(1),'b  (h p1) (w p2) -> b (h w)(p1 p2) ', p1=patch_size_h, p2=patch_size_w)

    patches_mean_1 = torch.mean(patches_1.detach(), dim=2)
    patches_mean_2 = torch.mean(patches_2.detach(), dim=2)

    patches_outputs_1 = rearrange(out_1, 'b c (h p1) (w p2)->b c (h w) (p1 p2)', p1=patch_size_h, p2=patch_size_w)
    patches_outputs_2 = rearrange(out_2, 'b c (h p1) (w p2)->b c (h w) (p1 p2)', p1=patch_size_h, p2=patch_size_w)
    patches_mean_outputs_1 = torch.mean(patches_outputs_1.detach(), dim=3).permute(0, 2, 1)  # torch.Size([8, 16, 4])
    patches_mean_outputs_2 = torch.mean(patches_outputs_2.detach(), dim=3).permute(0, 2, 1)  # torch.Size([8, 16, 4])

    patches_mean_1_top4_values, patches_mean_1_top4_indices = patches_mean_1.topk(top_num, dim=1)  # torch.Size([8, 4])
    patches_mean_2_top4_values, patches_mean_2_top4_indices = patches_mean_2.topk(top_num, dim=1)  # torch.Size([8, 4])

    for i in range(labeled_bs):
        if random.random() < 0.5:
            kl_similarities_1 = torch.empty(top_num)
            kl_similarities_2 = torch.empty(top_num)
            b = torch.argmin(patches_mean_1[i].detach(), dim=0)
            d = torch.argmin(patches_mean_2[i].detach(), dim=0)
            patches_mean_outputs_min_1 = patches_mean_outputs_1[i, b, :]  # torch.Size([4])
            patches_mean_outputs_min_2 = patches_mean_outputs_2[i, d, :]  # torch.Size([4])

            patches_mean_outputs_top4_1 = patches_mean_outputs_1[i, patches_mean_1_top4_indices[i, :], :]  # torch.Size([4, 4])
            patches_mean_outputs_top4_2 = patches_mean_outputs_2[i, patches_mean_2_top4_indices[i, :], :]  # torch.Size([4, 4])

            for j in range(top_num):
                kl_similarities_1[j] = torch.nn.functional.kl_div(patches_mean_outputs_top4_1[j].softmax(dim=-1).log(), patches_mean_outputs_min_2.softmax(dim=-1), reduction='sum')
                kl_similarities_2[j] = torch.nn.functional.kl_div(patches_mean_outputs_top4_2[j].softmax(dim=-1).log(), patches_mean_outputs_min_1.softmax(dim=-1), reduction='sum')

            a = torch.argmin(kl_similarities_1.detach(), dim=0, keepdim=False)
            c = torch.argmin(kl_similarities_2.detach(), dim=0, keepdim=False)

            a_ori = patches_mean_1_top4_indices[i, a]
            c_ori = patches_mean_2_top4_indices[i, c]

            max_patch_1 = ecg_patch_2[i][c_ori]  
            ecg_patch_1[i][b] = max_patch_1  
            max_patch_2 = ecg_patch_1[i][a_ori]
            ecg_patch_2[i][d] = max_patch_2
        else:
            a = torch.argmax(patches_mean_1[i].detach(), dim=0)
            b = torch.argmin(patches_mean_1[i].detach(), dim=0)
            c = torch.argmax(patches_mean_2[i].detach(), dim=0)
            d = torch.argmin(patches_mean_2[i].detach(), dim=0)

            max_patch_1 = ecg_patch_2[i][c]  
            ecg_patch_1[i][b] = max_patch_1  
            max_patch_2 = ecg_patch_1[i][a]
            ecg_patch_2[i][d] = max_patch_2
    ecg_patch = torch.cat([ecg_patch_1, ecg_patch_2], dim=0)
    ecg_patch_last = rearrange(ecg_patch, 'b (h w)(p1 p2) -> b  (h p1) (w p2)', h=h_size, w=w_size,p1=patch_size_h, p2=patch_size_w)
    return ecg_patch_last


def ABD_I(outputs1_max, outputs2_max, volume_batch, volume_batch_strong, label_batch, label_batch_strong, args):
    # ABD-I Bidirectional Displacement Patch
    patches_supervised_1 = rearrange(outputs1_max[:labeled_bs], 'b (h p1) (w p2)->b (h w) (p1 p2)', p1=patch_size_h, p2=patch_size_w)
    patches_supervised_2 = rearrange(outputs2_max[:labeled_bs], 'b (h p1) (w p2)->b (h w) (p1 p2)', p1=patch_size_h, p2=patch_size_w)
    ecg_patch_supervised_1 = rearrange(volume_batch.squeeze(1)[:labeled_bs], 'b  (h p1) (w p2) -> b (h w)(p1 p2) ', p1=patch_size_h, p2=patch_size_w)  # torch.Size([8, 224, 224])
    ecg_patch_supervised_2 = rearrange(volume_batch_strong.squeeze(1)[:labeled_bs], 'b  (h p1) (w p2) -> b (h w)(p1 p2) ', p1=patch_size_h, p2=patch_size_w)
    label_patch_supervised_1 = rearrange(label_batch[:labeled_bs], 'b  (h p1) (w p2) -> b (h w)(p1 p2) ', p1=patch_size_h, p2=patch_size_w)
    label_patch_supervised_2 = rearrange(label_batch_strong[:labeled_bs], 'b  (h p1) (w p2) -> b (h w)(p1 p2) ', p1=patch_size_h, p2=patch_size_w)
    patches_mean_supervised_1 = torch.mean(patches_supervised_1.detach(), dim=2)
    patches_mean_supervised_2 = torch.mean(patches_supervised_2.detach(), dim=2)
    e = torch.argmax(patches_mean_supervised_1.detach(), dim=1)
    f = torch.argmin(patches_mean_supervised_1.detach(), dim=1)
    g = torch.argmax(patches_mean_supervised_2.detach(), dim=1)
    h = torch.argmin(patches_mean_supervised_2.detach(), dim=1)
    for i in range(labeled_bs): 
        if random.random() < 0.5:
            min_patch_supervised_1 = ecg_patch_supervised_2[i][h[i]]  
            ecg_patch_supervised_1[i][e[i]] = min_patch_supervised_1
            min_patch_supervised_2 = ecg_patch_supervised_1[i][f[i]]
            ecg_patch_supervised_2[i][g[i]] = min_patch_supervised_2

            min_label_supervised_1 = label_patch_supervised_2[i][h[i]]
            label_patch_supervised_1[i][e[i]] = min_label_supervised_1
            min_label_supervised_2 = label_patch_supervised_1[i][f[i]]
            label_patch_supervised_2[i][g[i]] = min_label_supervised_2
    ecg_patch_supervised = torch.cat([ecg_patch_supervised_1, ecg_patch_supervised_2], dim=0)
    ecg_patch_supervised_last = rearrange(ecg_patch_supervised, 'b (h w)(p1 p2) -> b  (h p1) (w p2)', h=h_size, w=w_size,p1=patch_size_h, p2=patch_size_w)  # torch.Size([16, 224, 224])
    label_patch_supervised = torch.cat([label_patch_supervised_1, label_patch_supervised_2], dim=0)
    label_patch_supervised_last = rearrange(label_patch_supervised, 'b (h w)(p1 p2) -> b  (h p1) (w p2)', h=h_size, w=w_size,p1=patch_size_h, p2=patch_size_w)  # torch.Size([16, 224, 224])
    return ecg_patch_supervised_last, label_patch_supervised_last

In [3]:
# dataset
def iterate_once(iterable):
    return np.random.permutation(iterable)


def iterate_eternally(indices):
    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    return itertools.chain.from_iterable(infinite_shuffles())


def grouper(iterable, n):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3) --> ABC DEF"
    args = [iter(iterable)] * n
    return zip(*args)


class TwoStreamBatchSampler(Sampler):
    """Iterate two sets of indices

    An 'epoch' is one iteration through the primary indices.
    During the epoch, the secondary indices are iterated through
    as many times as needed.
    """
    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_eternally(self.secondary_indices)
        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch)
            in zip(grouper(primary_iter, self.primary_batch_size),
                    grouper(secondary_iter, self.secondary_batch_size))
        )

    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size


def cutout_ecg(ecg, label, p=0.5, size_min=0.02, size_max=0.4, value_min=0, value_max=1, pixel_level=True):
    """
    应用Cutout增强到ECG信号及其对应的标签。
    参数：
        ecg：ECG信号，形状为 (length, channels=1)
        label：标签，形状为 (length, channels=4)
        p：应用Cutout的概率
        size_min, size_max：擦除区域长度的最小和最大比例
        value_min, value_max：ECG填充值的范围
        pixel_level：是否使用逐位置随机值进行填充
    返回：
        ecg：增强后的ECG信号
        label：增强后的标签
    """
    if random.random() < p:
        # 确保输入为NumPy数组
        ecg = np.array(ecg)
        label = np.array(label)

        # 检查输入形状
        assert ecg.ndim == 2 and label.ndim == 2, "输入必须为二维数组"
        assert ecg.shape[1] == 1 and label.shape[1] == 4, "ECG的通道数必须为1，标签的通道数必须为4"
        assert ecg.shape[0] == label.shape[0], "ECG和标签的长度必须匹配"
        
        # 获取长度 (L)
        L = ecg.shape[0]

        # 确定擦除区域的宽度
        size = np.random.uniform(size_min, size_max) * L
        erase_w = max(1, int(size))  # 确保擦除宽度至少为1

        # 选择起始位置
        x = np.random.randint(0, L - erase_w + 1)

        # 为ECG生成填充值
        if pixel_level:
            value = np.random.randint(value_min, value_max + 1, (erase_w, 1))
        else:
            value = np.random.randint(value_min, value_max + 1)

        # 对ECG应用Cutout
        ecg[x:x + erase_w, :] = value

        # 对标签应用Cutout，设置为0
        label[x:x + erase_w, :] = 0

    return ecg, label


class WeakStrongAugment(object):
    def __init__(self):
        pass

    def __call__(self, ecg, label):
        ecg, label = random_resize(ecg, label, scale_range=(0.5, 2))
        if random.random() < 0.5:
            ecg, label = np.flip(ecg, axis=0).copy(), np.flip(label, axis=0).copy()

        # strong augmentation is color jitter
        ecg_strong, label_strong = cutout_ecg(ecg,label,p=0.5)
        if random.random() < 0.5:
            ecg_strong = ecg_strong + baseline_wander_noise(ecg_strong[:,0], fs=500, snr=-10, freq=0.15)[:,np.newaxis]
        if random.random() < 0.5:
            ecg_strong = ecg_strong + additive_white_gaussian_noise(ecg_strong[:,0], snr=10)[:,np.newaxis]

        ecg = zscore_normalize(ecg, axis=0)
        ecg_strong = zscore_normalize(ecg_strong, axis=0)

        ecg = torch.from_numpy(ecg.astype(np.float32)).permute(1, 0).unsqueeze(1) # (channel=1, 1, length)
        ecg_strong = torch.from_numpy(ecg_strong.astype(np.float32)).permute(1, 0).unsqueeze(1)
        label = torch.from_numpy(label.astype(np.float32)).permute(1, 0).unsqueeze(1)
        label_strong = torch.from_numpy(label_strong.astype(np.float32)).permute(1, 0).unsqueeze(1)

        sample = {
            "ecg": ecg,
            "ecg_strong": ecg_strong,
            "label": label,
            "label_strong": label_strong
        }
        return sample


class ECGDataset_ABD(Dataset):
    def __init__(self, x, y, transform=None, weak_strong_augment=None):
        self.x = x
        self.y = y
        self.transform = transform
        self.weak_strong_augment = weak_strong_augment

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        x_sample = self.x[idx]
        y_sample = self.y[idx]
        if self.weak_strong_augment:
            sample = self.weak_strong_augment(x_sample, y_sample)
            return sample
        
        if self.transform:
            for transform_func in self.transform:
                x_sample, y_sample = transform_func(x_sample, y_sample)

        x_sample = torch.from_numpy(x_sample.astype(np.float32)).permute(1, 0).unsqueeze(1) # (channel=1, 1, length)
        y_sample = torch.from_numpy(y_sample.astype(np.float32)).permute(1, 0).unsqueeze(1) # (channel=4, 1, length)
        sample = {
            "ecg": x_sample,
            "label": y_sample
        }
        return sample
    

def generate_mask(img):
    """
    随机生成遮罩，返回遮罩及对应的损失遮罩。
    针对 H=1 的数据，确保 patch_x 至少为1，并处理随机起始位置的边界条件。
    """
    batch_size, channel, img_x, img_y = img.shape
    loss_mask = torch.ones(batch_size, img_x, img_y).cuda()
    mask = torch.ones(img_x, img_y).cuda()
    # 如果 img_x == 1，则 force patch_x 为1；否则按比例计算
    patch_x = int(img_x * 2 / 3) if img_x > 1 else 1
    patch_y = int(img_y * 2 / 3)
    # 计算随机起始位置时，防止负值
    if img_x - patch_x <= 0:
        w = 0
    else:
        w = np.random.randint(0, img_x - patch_x)
    if img_y - patch_y <= 0:
        h = 0
    else:
        h = np.random.randint(0, img_y - patch_y)
    mask[w:w + patch_x, h:h + patch_y] = 0
    loss_mask[:, w:w + patch_x, h:h + patch_y] = 0
    return mask.long(), loss_mask.long()


class DiceLoss(nn.Module):
    def __init__(self, n_classes):
        super(DiceLoss, self).__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        tensor_list = []
        for i in range(self.n_classes):
            temp_prob = input_tensor == i * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim=1)
        return output_tensor.float()

    def _dice_loss(self, score, target):
        target = target.float()
        smooth = 1e-10
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss
    
    def _dice_mask_loss(self, score, target, mask):
        target = target.float()
        mask = mask.float()
        smooth = 1e-10
        intersect = torch.sum(score * target * mask)
        y_sum = torch.sum(target * target * mask)
        z_sum = torch.sum(score * score * mask)
        loss = (2 * intersect + smooth ) / (z_sum + y_sum + smooth)
        loss = 1 - loss
        return loss

    def forward(self, inputs, target, mask=None, weight=None, softmax=True, one_hot=False):
        # If softmax is True, apply softmax to the inputs
        if softmax:
            inputs = torch.softmax(inputs, dim=1)

        # One-hot encoding of target, only if one_hot is True
        if one_hot:
            target = self._one_hot_encoder(target)
        
        if weight is None:
            weight = [1] * self.n_classes

        assert inputs.size() == target.size(), 'predict & target shape do not match'
        
        class_wise_dice = []
        loss = 0.0
        if mask is not None:
            # Expand mask to match the number of classes
            mask = mask.repeat(1, self.n_classes, 1, 1).type(torch.float32)
            for i in range(0, self.n_classes): 
                dice = self._dice_mask_loss(inputs[:, i], target[:, i], mask[:, i])
                class_wise_dice.append(1.0 - dice.item())
                loss += dice * weight[i]
        else:
            for i in range(0, self.n_classes):
                dice = self._dice_loss(inputs[:, i], target[:, i])
                class_wise_dice.append(1.0 - dice.item())
                loss += dice * weight[i]
        
        return loss / self.n_classes
    

def mix_loss(output, img_l, patch_l, mask, l_weight=1.0, u_weight=0.5, unlab=False):
    CE = nn.CrossEntropyLoss(reduction='none')
    dice_loss = DiceLoss(n_classes=4)
    img_l, patch_l = img_l.type(torch.int64), patch_l.type(torch.int64)
    output_soft = F.softmax(output, dim=1)
    ecg_weight, patch_weight = l_weight, u_weight
    if unlab:
        ecg_weight, patch_weight = u_weight, l_weight
    patch_mask = 1 - mask
    loss_dice = dice_loss(output_soft, img_l, mask.unsqueeze(1)) * ecg_weight
    loss_dice += dice_loss(output_soft, patch_l, patch_mask.unsqueeze(1)) * patch_weight
    # loss_ce = ecg_weight * (CE(output, img_l) * mask).sum() / (mask.sum() + 1e-16) 
    # loss_ce += patch_weight * (CE(output, patch_l) * patch_mask).sum() / (patch_mask.sum() + 1e-16)#loss = loss_ce
    loss_ce = ecg_weight * (CE(output, img_l.argmax(dim=1)) * mask).sum() / (mask.sum() + 1e-16) 
    loss_ce += patch_weight * (CE(output, patch_l.argmax(dim=1)) * patch_mask).sum() / (patch_mask.sum() + 1e-16)#loss = loss_ce
    return loss_dice, loss_ce


def pre_train(model, snapshot_path, db_train_l, db_train_u, db_val):
    base_lr = 0.01
    num_classes = 4
    # max_iterations = 10000
    max_iterations = 1000
    # max_iterations = 1

    batch_size = 24
    labeled_bs = 12
    seed = 42
    labeled_sub_bs, unlabeled_sub_bs = int(labeled_bs/2), int((batch_size-labeled_bs) / 2)

    def worker_init_fn(worker_id):
        random.seed(seed + worker_id)

    db_train = torch.utils.data.ConcatDataset([db_train_l, db_train_u])
    num_total = len(db_train)
    num_labeled = len(db_train_l)
    labeled_idxs = list(range(0, num_labeled))
    unlabeled_idxs = list(range(num_labeled, num_total))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size-labeled_bs)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    model.train()
    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_val_loss = np.inf
    iterator = tqdm(range(max_epoch), ncols=70)
    for _ in iterator:
        for _, sampled_batch in enumerate(trainloader):
            volume_batch, label_batch = sampled_batch['ecg'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            img_a, img_b = volume_batch[:labeled_sub_bs], volume_batch[labeled_sub_bs:labeled_bs]
            lab_a, lab_b = label_batch[:labeled_sub_bs], label_batch[labeled_sub_bs:labeled_bs]
            img_mask, loss_mask = generate_mask(img_a)
            gt_mixl = lab_a * img_mask + lab_b * (1 - img_mask)

            #-- original
            net_input = img_a * img_mask + img_b * (1 - img_mask)
            out_mixl = model(net_input)
            loss_dice, loss_ce = mix_loss(out_mixl, lab_a, lab_b, loss_mask, u_weight=1.0, unlab=True)

            loss = (loss_dice + loss_ce) / 2            

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num += 1

            if iter_num > 0 and iter_num % 100 == 0:
            # if 1:
                model.eval()
                val_loss = 0
                for _, sampled_batch in enumerate(valloader):
                    val_volume, val_label = sampled_batch['ecg'], sampled_batch['label']
                    val_volume, val_label = val_volume.cuda(), val_label.cuda()
                    val_output = model(val_volume)
                    val_loss += DiceLoss(n_classes=4)(val_output, val_label).item()
                val_loss /= len(valloader)
                iterator.set_postfix({'Iter': iter_num, 'Val Loss': f'{val_loss:.4f}'})

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(model.state_dict(), snapshot_path)

                model.train()

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break


def get_ACDC_masks(output):
    probs = F.softmax(output, dim=1)
    _, probs = torch.max(output, dim=1)
    probs = F.one_hot(probs.squeeze(1).squeeze(1), num_classes=4).permute(0, 2, 1).unsqueeze(2).float()
    # if nms == 1:
    #     probs = get_ACDC_2DLargestCC(probs)      
    return probs

def sigmoid_rampup(current, rampup_length):
    """Exponential rampup from https://arxiv.org/abs/1610.02242"""
    if rampup_length == 0:
        return 1.0
    else:               
        current = np.clip(current, 0.0, rampup_length)
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))
    
def get_current_consistency_weight(epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    consistency = 0.1
    consistency_rampup = 200.0
    return 5* consistency * sigmoid_rampup(epoch, consistency_rampup)

def update_model_ema(model, ema_model, alpha):
    model_state = model.state_dict()
    model_ema_state = ema_model.state_dict()
    new_dict = {}
    for key in model_state:
        new_dict[key] = alpha * model_ema_state[key] + (1 - alpha) * model_state[key]
    ema_model.load_state_dict(new_dict)


def self_train(model_1, model_2, ema_model, snapshot_path, final_path_1, final_path_2, db_train_l, db_train_u, db_val):
    base_lr = 0.01
    num_classes = 4
    # max_iterations = 30000
    max_iterations = 1000
    # max_iterations = 1
    
    seed = 42
    u_weight = 0.5
    labeled_sub_bs, unlabeled_sub_bs = int(labeled_bs/2), int((batch_size-labeled_bs) / 2)

    def worker_init_fn(worker_id):
        random.seed(seed + worker_id)

    db_train = torch.utils.data.ConcatDataset([db_train_l, db_train_u])
    num_total = len(db_train)
    num_labeled = len(db_train_l)
    labeled_idxs = list(range(0, num_labeled))
    unlabeled_idxs = list(range(num_labeled, num_total))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, batch_size, batch_size-labeled_bs)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    optimizer1 = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    optimizer2 = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    model_1.load_state_dict(torch.load(snapshot_path))    
    model_2.load_state_dict(torch.load(snapshot_path))
    ema_model.load_state_dict(torch.load(snapshot_path))

    model_1.train()
    model_2.train()
    ema_model.train()

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_val_loss_1 = np.inf
    best_val_loss_2 = np.inf
    iterator = tqdm(range(max_epoch), ncols=70)
    for _ in iterator:
        for _, sampled_batch in enumerate(trainloader):
            volume_batch, label_batch = sampled_batch['ecg'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            volume_batch_strong, label_batch_strong = sampled_batch['ecg_strong'], sampled_batch['label_strong']
            volume_batch_strong, label_batch_strong = volume_batch_strong.cuda(), label_batch_strong.cuda()

            img_a, img_b = volume_batch[:labeled_sub_bs], volume_batch[labeled_sub_bs:labeled_bs]
            uimg_a, uimg_b = volume_batch[labeled_bs:labeled_bs + unlabeled_sub_bs], volume_batch[labeled_bs + unlabeled_sub_bs:]
            lab_a, lab_b = label_batch[:labeled_sub_bs], label_batch[labeled_sub_bs:labeled_bs]

            img_a_s, img_b_s = volume_batch_strong[:labeled_sub_bs], volume_batch_strong[labeled_sub_bs:labeled_bs]
            uimg_a_s, uimg_b_s = volume_batch_strong[labeled_bs:labeled_bs + unlabeled_sub_bs], volume_batch_strong[labeled_bs + unlabeled_sub_bs:]
            lab_a_s, lab_b_s = label_batch_strong[:labeled_sub_bs], label_batch_strong[labeled_sub_bs:labeled_bs]
            
            with torch.no_grad():
                pre_a = ema_model(uimg_a)
                pre_b = ema_model(uimg_b)
                plab_a = get_ACDC_masks(pre_a)
                plab_b = get_ACDC_masks(pre_b)
                pre_a_s = ema_model(uimg_a_s)
                pre_b_s = ema_model(uimg_b_s)
                plab_a_s = get_ACDC_masks(pre_a_s)
                plab_b_s = get_ACDC_masks(pre_b_s)
                img_mask, loss_mask = generate_mask(img_a)
            consistency_weight = get_current_consistency_weight(iter_num//150)

            net_input_unl_1 = uimg_a * img_mask + img_a * (1 - img_mask)
            net_input_l_1 = img_b * img_mask + uimg_b * (1 - img_mask)
            net_input_1 = torch.cat([net_input_unl_1, net_input_l_1], dim=0) 

            net_input_unl_2 = uimg_a_s * img_mask + img_a_s * (1 - img_mask)
            net_input_l_2 = img_b_s * img_mask + uimg_b_s * (1 - img_mask)
            net_input_2 = torch.cat([net_input_unl_2, net_input_l_2], dim=0)


            # Model1 Loss
            out_unl_1 = model_1(net_input_unl_1)
            out_l_1 = model_1(net_input_l_1)
            out_1 = torch.cat([out_unl_1, out_l_1], dim=0)
            out_soft_1 = torch.softmax(out_1, dim=1)
            out_max_1 = torch.max(out_soft_1.detach(), dim=1)[0]
            out_pseudo_1 = torch.argmax(out_soft_1.detach(), dim=1, keepdim=False) 
            unl_dice_1, unl_ce_1 = mix_loss(out_unl_1, plab_a, lab_a, loss_mask, u_weight=u_weight, unlab=True)
            l_dice_1, l_ce_1 = mix_loss(out_l_1, lab_b, plab_b, loss_mask, u_weight=u_weight)
            loss_ce_1 = unl_ce_1 + l_ce_1
            loss_dice_1 = unl_dice_1 + l_dice_1

            # Model2 Loss
            out_unl_2 = model_2(net_input_unl_2)
            out_l_2 = model_2(net_input_l_2)
            out_2 = torch.cat([out_unl_2, out_l_2], dim=0)
            out_soft_2 = torch.softmax(out_2, dim=1)
            out_max_2 = torch.max(out_soft_2.detach(), dim=1)[0]
            out_pseudo_2 = torch.argmax(out_soft_2.detach(), dim=1, keepdim=False) 
            unl_dice_2, unl_ce_2 = mix_loss(out_unl_2, plab_a_s, lab_a_s, loss_mask, u_weight=u_weight, unlab=True)
            l_dice_2, l_ce_2 = mix_loss(out_l_2, lab_b_s, plab_b_s, loss_mask, u_weight=u_weight)
            loss_ce_2 = unl_ce_2 + l_ce_2
            loss_dice_2 = unl_dice_2 + l_dice_2

            dice_loss = DiceLoss(n_classes=4)
            # Model1 & Model2 Cross Pseudo Supervision
            pseudo_supervision1 = dice_loss(out_soft_1, out_pseudo_2.unsqueeze(1), softmax=False, one_hot=True)  
            pseudo_supervision2 = dice_loss(out_soft_2, out_pseudo_1.unsqueeze(1), softmax=False, one_hot=True)  
            # ABD-R New Training Sample
            ecg_patch_last = ABD_R_BCP(out_max_1, out_max_2, net_input_1, net_input_2, out_1, out_2)
            ecg_output_1 = model_1(ecg_patch_last.unsqueeze(1))  
            ecg_output_soft_1 = torch.softmax(ecg_output_1, dim=1)
            pseudo_ecg_output_1 = torch.argmax(ecg_output_soft_1.detach(), dim=1, keepdim=False)
            ecg_output_2 = model_2(ecg_patch_last.unsqueeze(1))
            ecg_output_soft_2 = torch.softmax(ecg_output_2, dim=1)
            pseudo_ecg_output_2 = torch.argmax(ecg_output_soft_2.detach(), dim=1, keepdim=False)
            # Model1 & Model2 Second Step Cross Pseudo Supervision
            pseudo_supervision3 = dice_loss(ecg_output_soft_1, pseudo_ecg_output_2.unsqueeze(1), softmax=False, one_hot=True)
            pseudo_supervision4 = dice_loss(ecg_output_soft_2, pseudo_ecg_output_1.unsqueeze(1), softmax=False, one_hot=True)

            loss_1 = (loss_dice_1 + loss_ce_1) / 2 + pseudo_supervision1 + pseudo_supervision3
            loss_2 = (loss_dice_2 + loss_ce_2) / 2 + pseudo_supervision2 + pseudo_supervision4
            loss = loss_1 + loss_2


            optimizer1.zero_grad()
            optimizer2.zero_grad()

            loss.backward()
            optimizer1.step()
            optimizer2.step()

            iter_num += 1
            update_model_ema(model_1, ema_model, 0.99)

            if iter_num > 0 and iter_num % 100 == 0:
            # if 1:
                model_1.eval()
                val_loss = 0
                for _, sampled_batch in enumerate(valloader):
                    val_volume, val_label = sampled_batch['ecg'], sampled_batch['label']
                    val_volume, val_label = val_volume.cuda(), val_label.cuda()
                    val_output = model_1(val_volume)
                    val_loss += DiceLoss(n_classes=4)(val_output, val_label).item()
                val_loss /= len(valloader)

                if val_loss < best_val_loss_1:
                    best_val_loss_1 = val_loss
                    torch.save(model_1.state_dict(), final_path_1)
                model_1.train()

                model_2.eval()
                val_loss = 0
                for _, sampled_batch in enumerate(valloader):
                    val_volume, val_label = sampled_batch['ecg'], sampled_batch['label']
                    val_volume, val_label = val_volume.cuda(), val_label.cuda()
                    val_output = model_2(val_volume)
                    val_loss += DiceLoss(n_classes=4)(val_output, val_label).item()
                val_loss /= len(valloader)

                if val_loss < best_val_loss_2:
                    best_val_loss_2 = val_loss
                    torch.save(model_2.state_dict(), final_path_2)
                model_2.train()

                iterator.set_postfix({'Iter': iter_num, 'Val Loss1': f'{best_val_loss_1:.4f}', 'Val Loss2': f'{best_val_loss_2:.4f}'})    

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break

In [4]:
th_delineation = 150
gpu = 2
aug = 2
deep_supervision = 0
torch.cuda.set_device(gpu)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Parameters
model_save_path = f'./checkpoints/unet_a_ds{int(deep_supervision)}_ABD.cross'
metrics_save_path = f'./metrics/unet_a_ds{int(deep_supervision)}_ABD.cross'
records_list = [5,10,20,50,160]
val_ratio = 0.2
df_list = []

for num_labeled in records_list:
    ## Train
    print(f"Number of labeled data: {num_labeled}")
    print("Training Stage")
    for fold in range(5):
        if os.path.exists(f"{model_save_path}.num_labeled_{num_labeled}_fold_{fold}.final_1.pth"):
            continue
        # Set random seed for reproducibility
        seed = 42
        cudnn.benchmark = False
        cudnn.deterministic = True
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        x_train_l, y_train_l, _, _, _ = raw_data_load_ludb(40, num_labeled, fold, crop=[1250, 3750])
        x_train_u, y_train_u, _, _, _ = raw_data_load_rdb(400, 1999, fold, crop=[1250, 3750])

        num_val =  np.round(x_train_l.shape[0] * val_ratio).astype(int)
        x_val, y_val = x_train_l[:num_val], y_train_l[:num_val]
        x_train_l, y_train_l = x_train_l[num_val:], y_train_l[num_val:]

        print(f"Fold {fold+1}/{5}: Train labeled: {x_train_l.shape[0]}, Train unlabeled: {x_train_u.shape[0]}, Val: {x_val.shape[0]}")

        db_train_l = ECGDataset_ABD(x_train_l, y_train_l, weak_strong_augment=WeakStrongAugment())
        db_train_u = ECGDataset_ABD(x_train_u, y_train_u, weak_strong_augment=WeakStrongAugment())
        db_val = ECGDataset_ABD(x_val, y_val, transform=base_transforms())

        model = UNet1D_A(length=2500, base_channels=16, kernel_size=9, dropout='channels', droprate=.2, num_classes=2, logits=True).to('cuda')
        ini_ds, ini_aug = int(deep_supervision), int(aug)
        model_load_path = f"./checkpoints/unet_a_ds{ini_ds}.num_labeled_{num_labeled}_aug_{ini_aug}.fold_{fold}.epoch_20.pth"
        model.load_state_dict(torch.load(model_load_path))
        
        snapshot_path = f"{model_save_path}.num_labeled_{num_labeled}_fold_{fold}.pre_train.pth"
        pre_train(model, snapshot_path, db_train_l, db_train_u, db_val)

        model_1 = copy.deepcopy(model)
        model_2 = copy.deepcopy(model)
        ema_model = copy.deepcopy(model)
        final_path_1 = f"{model_save_path}.num_labeled_{num_labeled}_fold_{fold}.final_1.pth"
        final_path_2 = f"{model_save_path}.num_labeled_{num_labeled}_fold_{fold}.final_2.pth"
        self_train(model_1, model_2, ema_model, snapshot_path, final_path_1, final_path_2, db_train_l, db_train_u, db_val)

    ## Test
    print("Test Stage")
    data = []
    label = []
    preds = []
    seg_metrics_macro = []
    deli_metrics_macro = []
    
    for fold in range(5):
        model = UNet1D_A(length=2500, base_channels=16, kernel_size=9, dropout='channels', droprate=.2, num_classes=2, logits=True).to(device)
        model_load_path = f"{model_save_path}.num_labeled_{num_labeled}_fold_{fold}.final_1.pth"
        model.load_state_dict(torch.load(model_load_path))

        x_train, y_train, _, x_test, y_test = raw_data_load_ludb(40, 160, fold, crop=[0, 5000])
        test_dataset = ECGDataset(x_test, y_test, transform=base_transforms())

        pred = model_predict(model, model_load_path, test_dataset, device, multi_lead_correction=False, logits=True)
        
        flag_ludb = np.load('./dataset/ludb/flag.npy')
        index_shuffled_5fold = np.load('./dataset/ludb/ludb_index_shuffled_5fold_250113.npy')
        index_shuffled = index_shuffled_5fold[:,fold]
        index_shuffled_lead = []
        for i in np.array(index_shuffled):
            index_shuffled_lead.extend([k for k in range(12*i,12*i+12,1)])
        num_test = 40
        flag_test= flag_ludb[index_shuffled_lead[0:num_test*12]]
        dataset = (x_test, y_test, np.zeros((x_test.shape[0],)), flag_test)
        _, _, seg_metrics, deli_metrics = dataset_eval(dataset, pred, th_delineation=th_delineation, verbose=0)

        data.append(x_test)
        label.append(y_test)
        preds.append(pred)
        seg_metrics_macro.append(seg_metrics)
        deli_metrics_macro.append(deli_metrics)

    data = np.concatenate(data, axis=0)
    label = np.concatenate(label, axis=0)
    preds = np.concatenate(preds, axis=0)

    def summarize_total_rows(dfs):
        """
        Extracts 'Total' rows from DataFrames, calculates summary statistics,
        and returns a new DataFrame.

        Args:
        dfs: A list of pandas DataFrames with identical structure.

        Returns:
            A pandas DataFrame containing the mean, std, max, and min
            of each column of the 'Total' rows from all input DataFrames.
        """
        total_rows = [df[df['type'] == 'Total'].iloc[0] for df in dfs]
        total_df = pd.DataFrame(total_rows)

        # Get original headers, and remove 'type'
        original_headers = total_df.columns.tolist()
        original_headers.remove('type')


        # Calculate summary stats for each column (excluding type)
        summary_data = {
            'mean': total_df[original_headers].mean().to_list(),
            'std': total_df[original_headers].std().to_list(),
            'max': total_df[original_headers].max().to_list(),
            'min': total_df[original_headers].min().to_list()
        }
        # Create the summary DataFrame
        summary_df = pd.DataFrame(summary_data, index = original_headers)
        return summary_df
    
    # Macro average metrics of 5 folds
    seg_metrics_macro = summarize_total_rows(seg_metrics_macro)
    deli_metrics_macro = summarize_total_rows(deli_metrics_macro) 
    filtered_df = deli_metrics_macro[deli_metrics_macro.index.str.contains('f1')]
    merged_df_macro = pd.concat([seg_metrics_macro, filtered_df], axis = 0)                

    # Micro average metrics of 5 folds
    dataset = (data, label, np.zeros((data.shape[0],)), np.zeros((data.shape[0],)))
    _, _, seg_metrics_micro, deli_metrics_micro = dataset_eval(dataset, preds, th_delineation=th_delineation, verbose=0)
    # Filter df2 to include rows where column name contain 'f1'
    filtered_df = deli_metrics_micro[['type'] + [col for col in deli_metrics_micro.columns if 'f1' in col]]
    merged_df_micro = pd.merge(seg_metrics_micro, filtered_df, on='type', how='outer')
    micro_row = merged_df_micro[merged_df_micro['type'] == 'Total'].iloc[0]
    # Remove 'type' and convert to series
    micro_row_values = micro_row.drop('type')
    merged_df = copy.deepcopy(merged_df_macro)
    merged_df['micro'] = micro_row_values
    
    df_list.append(merged_df)
    # Final results
    print(merged_df)


# Save results
# 拼接数据，将总标题作为列上方的“标题行”
concat_frames = []
for i, df in enumerate(df_list):
    # 插入标题行
    df_with_title = df.copy()
    df_with_title.columns = pd.MultiIndex.from_tuples([(str(records_list[i]), col) for col in df.columns])
    concat_frames.append(df_with_title)

# 按列拼接，并保留行标题
result = pd.concat(concat_frames, axis=1)

# 写入 Excel
result.to_excel(f"{metrics_save_path}.xlsx")

Number of labeled data: 5
Training Stage
Test Stage
                mean       std       max       min     micro
iou_p       0.654467  0.052165  0.718982  0.598013  0.649111
iou_qrs     0.828871  0.029985  0.858051  0.785064  0.820951
iou_t       0.741357  0.040647  0.786865  0.686746  0.738431
miou        0.741565  0.039961  0.783642  0.689941  0.736165
acc         0.888106  0.015212  0.902747  0.865644  0.886333
ave_f1      0.902306  0.029515  0.931868  0.855995  0.897126
f1_p_on     0.843134  0.044981  0.886941  0.769195  0.835662
f1_p_end    0.843730  0.045337  0.886941  0.768962  0.836292
f1_qrs_on   0.982634  0.008273  0.990287  0.972554  0.978863
f1_qrs_end  0.980901  0.008443  0.990213  0.969108  0.976913
f1_t_on     0.882773  0.037549  0.924414  0.825019  0.878272
f1_t_end    0.880666  0.039063  0.924414  0.821716  0.876755
Number of labeled data: 10
Training Stage
Fold 2/5: Train labeled: 96, Train unlabeled: 23970, Val: 24


 98%|██▉| 124/126 [07:19<00:07,  3.54s/it, Iter=1000, Val Loss=0.1416]
 98%|▉| 124/126 [30:02<00:29, 14.54s/it, Iter=1000, Val Loss1=0.1190, 


Fold 3/5: Train labeled: 96, Train unlabeled: 23970, Val: 24


 98%|██▉| 124/126 [07:20<00:07,  3.56s/it, Iter=1000, Val Loss=0.1026]
 98%|▉| 124/126 [30:03<00:29, 14.54s/it, Iter=1000, Val Loss1=0.0964, 


Fold 4/5: Train labeled: 96, Train unlabeled: 23970, Val: 24


 98%|██▉| 124/126 [07:20<00:07,  3.55s/it, Iter=1000, Val Loss=0.2384]
 98%|▉| 124/126 [30:06<00:29, 14.57s/it, Iter=1000, Val Loss1=0.2311, 


Fold 5/5: Train labeled: 96, Train unlabeled: 23970, Val: 24


 98%|██▉| 124/126 [07:22<00:07,  3.57s/it, Iter=1000, Val Loss=0.1444]
 98%|▉| 124/126 [30:04<00:29, 14.55s/it, Iter=1000, Val Loss1=0.1191, 


Test Stage
                mean       std       max       min     micro
iou_p       0.674930  0.040818  0.709588  0.608927  0.671591
iou_qrs     0.832719  0.014127  0.854455  0.815470  0.826534
iou_t       0.763983  0.031715  0.810767  0.726340  0.761802
miou        0.757211  0.023890  0.777730  0.716912  0.753309
acc         0.895980  0.008986  0.906776  0.882516  0.894465
ave_f1      0.916965  0.017226  0.935150  0.891209  0.914637
f1_p_on     0.864705  0.036497  0.907457  0.812914  0.861329
f1_p_end    0.865101  0.037112  0.908299  0.811761  0.861718
f1_qrs_on   0.980763  0.003147  0.984172  0.976485  0.979684
f1_qrs_end  0.979066  0.004496  0.983726  0.972525  0.977158
f1_t_on     0.905944  0.019208  0.928054  0.884620  0.903749
f1_t_end    0.906213  0.019790  0.929053  0.885606  0.904185
Number of labeled data: 20
Training Stage
Fold 1/5: Train labeled: 192, Train unlabeled: 23970, Val: 48


 98%|████▉| 62/63 [07:08<00:06,  6.92s/it, Iter=1000, Val Loss=0.1158]
 98%|▉| 62/63 [29:30<00:28, 28.56s/it, Iter=1000, Val Loss1=0.0992, Va


Fold 2/5: Train labeled: 190, Train unlabeled: 23970, Val: 47


 99%|████▉| 66/67 [07:11<00:06,  6.53s/it, Iter=1000, Val Loss=0.1479]
 99%|▉| 66/67 [29:41<00:26, 26.99s/it, Iter=1000, Val Loss1=0.1384, Va


Fold 3/5: Train labeled: 192, Train unlabeled: 23970, Val: 48


 98%|████▉| 62/63 [07:12<00:06,  6.98s/it, Iter=1000, Val Loss=0.0633]
 98%|▉| 62/63 [29:43<00:28, 28.77s/it, Iter=1000, Val Loss1=0.0623, Va


Fold 4/5: Train labeled: 192, Train unlabeled: 23970, Val: 48


 98%|████▉| 62/63 [07:15<00:07,  7.02s/it, Iter=1000, Val Loss=0.1667]
 98%|▉| 62/63 [29:35<00:28, 28.64s/it, Iter=1000, Val Loss1=0.1497, Va


Fold 5/5: Train labeled: 192, Train unlabeled: 23970, Val: 48


 98%|████▉| 62/63 [07:13<00:06,  6.99s/it, Iter=1000, Val Loss=0.0979]
 98%|▉| 62/63 [29:40<00:28, 28.71s/it, Iter=1000, Val Loss1=0.0922, Va


Test Stage
                mean       std       max       min     micro
iou_p       0.724391  0.034063  0.745769  0.664049  0.719862
iou_qrs     0.851800  0.014756  0.871721  0.833039  0.850735
iou_t       0.789371  0.022893  0.825071  0.766566  0.789772
miou        0.788521  0.021052  0.814187  0.756479   0.78679
acc         0.907968  0.008426  0.920775  0.898117  0.907333
ave_f1      0.922450  0.018117  0.945488  0.898769  0.920805
f1_p_on     0.883404  0.040654  0.910360  0.813487  0.878722
f1_p_end    0.883422  0.040157  0.910327  0.814022  0.878722
f1_qrs_on   0.981309  0.005562  0.987116  0.974619  0.981141
f1_qrs_end  0.981213  0.005861  0.987941  0.976131  0.980916
f1_t_on     0.902477  0.024540  0.937306  0.872264  0.902425
f1_t_end    0.902873  0.024821  0.940286  0.874188  0.902902
Number of labeled data: 50
Training Stage
Fold 1/5: Train labeled: 479, Train unlabeled: 23970, Val: 120


 96%|████▊| 25/26 [07:10<00:17, 17.24s/it, Iter=1000, Val Loss=0.0959]
 96%|▉| 25/26 [29:26<01:10, 70.64s/it, Iter=1000, Val Loss1=0.0900, Va


Fold 2/5: Train labeled: 477, Train unlabeled: 23970, Val: 119


 96%|████▊| 25/26 [07:10<00:17, 17.23s/it, Iter=1000, Val Loss=0.1456]
 96%|▉| 25/26 [29:30<01:10, 70.81s/it, Iter=1000, Val Loss1=0.1330, Va


Fold 3/5: Train labeled: 474, Train unlabeled: 23970, Val: 119


 96%|████▊| 25/26 [07:10<00:17, 17.23s/it, Iter=1000, Val Loss=0.0807]
 96%|▉| 25/26 [29:23<01:10, 70.54s/it, Iter=1000, Val Loss1=0.0762, Va


Fold 4/5: Train labeled: 478, Train unlabeled: 23970, Val: 119


 96%|████▊| 25/26 [07:09<00:17, 17.19s/it, Iter=1000, Val Loss=0.1186]
 96%|▉| 25/26 [29:40<01:11, 71.22s/it, Iter=1000, Val Loss1=0.0982, Va


Fold 5/5: Train labeled: 474, Train unlabeled: 23970, Val: 119


 96%|████▊| 25/26 [07:11<00:17, 17.24s/it, Iter=1000, Val Loss=0.1258]
 96%|▉| 25/26 [29:22<01:10, 70.50s/it, Iter=1000, Val Loss1=0.1139, Va


Test Stage
                mean       std       max       min     micro
iou_p       0.756787  0.019218  0.782320  0.741078  0.751199
iou_qrs     0.870373  0.014751  0.889649  0.852953  0.869125
iou_t       0.796157  0.018259  0.817587  0.770363  0.796968
miou        0.807772  0.016599  0.826598  0.789459  0.805764
acc         0.916096  0.007330  0.925887  0.907930  0.914433
ave_f1      0.925689  0.008332  0.936405  0.913632  0.924212
f1_p_on     0.895723  0.027312  0.924596  0.863389  0.890855
f1_p_end    0.895370  0.026927  0.924596  0.862613   0.89047
f1_qrs_on   0.984330  0.003161  0.987333  0.979813  0.984388
f1_qrs_end  0.984238  0.002916  0.987331  0.980059  0.983892
f1_t_on     0.896607  0.013940  0.911434  0.877533  0.897267
f1_t_end    0.897867  0.013719  0.912137  0.880806  0.898403
Number of labeled data: 160
Training Stage
Fold 1/5: Train labeled: 1536, Train unlabeled: 23970, Val: 384


 88%|██████▏| 7/8 [07:36<01:05, 65.18s/it, Iter=1000, Val Loss=0.1287]
 88%|▉| 7/8 [30:11<04:18, 258.76s/it, Iter=1000, Val Loss1=0.1168, Val


Fold 2/5: Train labeled: 1536, Train unlabeled: 23970, Val: 384


 88%|██████▏| 7/8 [07:37<01:05, 65.34s/it, Iter=1000, Val Loss=0.1091]
 88%|▉| 7/8 [30:30<04:21, 261.46s/it, Iter=1000, Val Loss1=0.1081, Val


Fold 3/5: Train labeled: 1536, Train unlabeled: 23970, Val: 384


 88%|██████▏| 7/8 [07:37<01:05, 65.32s/it, Iter=1000, Val Loss=0.0979]
 88%|▉| 7/8 [30:12<04:18, 258.91s/it, Iter=1000, Val Loss1=0.0979, Val


Fold 4/5: Train labeled: 1536, Train unlabeled: 23970, Val: 384


 88%|██████▏| 7/8 [07:35<01:05, 65.14s/it, Iter=1000, Val Loss=0.0846]
 88%|▉| 7/8 [30:39<04:22, 262.77s/it, Iter=1000, Val Loss1=0.0786, Val


Fold 5/5: Train labeled: 1536, Train unlabeled: 23970, Val: 384


 88%|██████▏| 7/8 [07:37<01:05, 65.39s/it, Iter=1000, Val Loss=0.0963]
 88%|▉| 7/8 [30:15<04:19, 259.35s/it, Iter=1000, Val Loss1=0.0917, Val


Test Stage
                mean       std       max       min     micro
iou_p       0.771656  0.024890  0.791474  0.737266  0.766187
iou_qrs     0.878550  0.016620  0.899821  0.856955  0.877276
iou_t       0.811637  0.017688  0.837510  0.796128  0.812247
miou        0.820615  0.014458  0.842077  0.803021   0.81857
acc         0.921159  0.006984  0.932516  0.916084  0.920253
ave_f1      0.927909  0.008082  0.936183  0.917255  0.926874
f1_p_on     0.895075  0.021680  0.928172  0.874530  0.891557
f1_p_end    0.895416  0.021401  0.928172  0.875644  0.891885
f1_qrs_on   0.981458  0.007389  0.989147  0.971532  0.981219
f1_qrs_end  0.980753  0.007942  0.988918  0.971532  0.980099
f1_t_on     0.906710  0.006212  0.913568  0.899171  0.907646
f1_t_end    0.908041  0.005962  0.918201  0.903378  0.908839
