In [1]:
import numpy as np 
import matplotlib.pyplot as plt 
import os 
import h5py
import torch 
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from scipy.ndimage import zoom, rotate
import itertools
from torch.utils.data import Sampler
from medpy import metric
import logging
from tqdm import tqdm
import random
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter
from skimage.measure import label 
import sys

# Params

In [2]:
class params: 
    def __init__(self): 
        self.root_dir = 'ACDC' 
        self.exp = 'MBCP' 
        self.model = 'unet' 
        self.pretrain_iterations = 50
        
        self.selftrain_iterations = 10 
        self.batch_size = 24
        self.deterministic = 1 
        self.base_lr = 1e-3
        self.img_size = [256,256] 
        self.seed = 42 
        self.num_classes = 4 
        self.patch_size = 16
        self.mask_ratio = 0.5

        # label and unlabel 
        self.labeled_bs = 12
        self.label_num = 7 
        self.u_weight = 0.5 

        # Cost 
        self.gpu = '0' 
        self.consistency = 0.1
        self.consistency_rampup = 200.0 
        self.magnitude = '6.0' 
        self.s_param = 6

        # Caussl parameters 
        self.consistency_type = 'mse'   
        self.max_step = 60 
        self.min_step = 60 
        self.start_step1 = 50 
        self.start_step2 = 50 
        self.cofficient = 3.0 
        self.max_iteration = 5000 
        self.thres_iteration = 20 

args = params() 

# ACDC Dataset

In [3]:
class ACDCDataset(Dataset): 
    """
    Use to load ACDC dataset 
    This dataset support 3 modes (strings only): 
    - train_lab: labeled training data 
    - train_unlab: unlabeled traning data 
    - val: validation data 

    Parameters: 
    - base_dir (str): folder save data 
    - split (str): type of data want to load
    - reservse (str): use to reverse the index of data 
    - transform (torchvision.transform): the transform apply for data 
    """
    def __init__(self, base_dir, split='train_lab', reverse=None, transform=None): 
        super(ACDCDataset, self).__init__() 
        self.base_dir = base_dir
        self.split = split
        self.reverse = reverse
        self.transform = transform
        self.sample_list = []

        # Read the file 
        if self.split == 'train_lab': 
            with open(os.path.join(self.base_dir, 'train_lab.list'), 'r') as file: 
                self.sample_list = file.readlines() 
        elif self.split == 'train_unlab': 
            with open(os.path.join(self.base_dir, 'train_unlab.list'), 'r') as file: 
                self.sample_list = file.readlines() 
        elif self.split == 'val': 
            with open(os.path.join(self.base_dir, 'val.list'), 'r') as file: 
                self.sample_list = file.readlines() 
        elif self.split == 'reconstruct': 
            with open(os.path.join(self.base_dir, 'all_slices.list'), 'r') as file: 
                self.sample_list = file.readlines() 
        else: 
            raise ValueError(f'Split: {self.split} is not support for ACDC dataset')
        
        self.sample_list = [item.replace('\n', '') for item in self.sample_list]
        print(f'Mode: {self.split}: {len(self.sample_list)} samples in total')


    def __len__(self): 
        if (self.split == 'train_lab') | (self.split == 'train_unlab'): 
            return len(self.sample_list) * 10 # Why use it ???? 
        
        return len(self.sample_list)

    def __getitem__(self, idx): 
        case = self.sample_list[idx%len(self.sample_list)] # Avoid problem of __len__ 
        if self.reverse: 
            case = self.sample_list[len(self.sample_list) - idx%len(self.sample_list) - 1] 

        # read the file 
        if (self.split == 'train_lab') | (self.split == 'train_unlab') | (self.split == 'reconstruct'): 
            h5f = h5py.File((self.base_dir + f'/data/slices/{case}.h5'), 'r')         
        elif (self.split == 'val'): 
            h5f = h5py.File((self.base_dir + f'/data/{case}.h5'), 'r')
        
        image = h5f['image'][:]
        label = h5f['label'][:]
        sample = {'image': image, 'label': label}

        if self.transform: 
            sample = self.transform(sample)
        image_, label_ = sample['image'], sample['label']
        return image_, label_

In [4]:
def random_rot_flip(image, label): 
    """
    Random rotate and Random flip 
    """
    
    # Random rotate
    k = np.random.randint(0, 4) 
    image = np.rot90(image, k)
    label = np.rot90(label, k)

    # Random flip 
    axis = np.random.randint(0, 2)
    image = np.flip(image, axis).copy() 
    label = np.flip(label, axis).copy() 

    return image, label 

def random_rotate(image, label):
    angle = np.random.randint(-20, 20) 
    image = rotate(image, angle, order= 0, reshape= False)
    label = rotate(label,angle, order=0, reshape= False )
    return image, label


class RandomGenerator: 
    def __init__(self, output_size): 
        self.output_size = output_size
    
    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        if np.random.random() > 0.5: 
            image, label = random_rot_flip(image, label)
        
        if np.random.random() > 0.5: 
            image, label = random_rotate(image, label) 
        
        # Zoom image to -> [256,256]
        x,y = image.shape
        image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order= 0)
        label = zoom(label, (self.output_size[0] /x , self.output_size[1] / y), order= 0)

        # Convert to pytorch 
        imageTensor = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) # image.shape = (1, H, W)
        labelTensor = torch.from_numpy(label.astype(np.uint8)) # label.shape = (H, W)
        sample = {'image': imageTensor, 'label': labelTensor}
        
        return sample

In [5]:
def iterate_once(indices): 
    """
    Permutate the iterable once 
    (permutate the labeled_idxs once)
    """
    return np.random.permutation(indices) 

def iterate_externally(indices): 
    """
    Create an infinite iterator that repeatedly permutes the indices.
    ( permutate the unlabeled_idxs to make different)
    """
    def infinite_shuffles(): 
        while True: 
            yield np.random.permutation(indices)
            
    return itertools.chain.from_iterable(infinite_shuffles())

def grouper(iterable, n): 
    args = [iter(iterable)] * n 
    return zip(*args)

class TwoStreamBatchSampler(Sampler): 
    def __init__(self, primary_indicies, secondary_indicies, batchsize, secondary_batchsize): 
        self.primary_indicies = primary_indicies
        self.secondary_indicies = secondary_indicies
        self.primary_batchsize = batchsize - secondary_batchsize
        self.secondary_batchsize = secondary_batchsize

        assert len(self.primary_indicies) >= self.primary_batchsize > 0 
        assert len(self.secondary_indicies) >= self.secondary_batchsize > 0 

    def __iter__(self):
        primary_iter = iterate_once(self.primary_indicies)
        secondary_iter = iterate_externally(self.secondary_indicies)

        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch) 
            in zip(grouper(primary_iter, self.primary_batchsize),
                   grouper(secondary_iter, self.secondary_batchsize))
        )

    def __len__(self): 
        return len(self.primary_indicies) // self.primary_batchsize

In [6]:
# Split the data 
def patients_to_slices(dataset, patients_num): 
    ref_dict = {} 
    if "ACDC" in dataset: 
        ref_dict = {'1': 32, '3': 68, '7': 136, '14': 256, '21': 396, '28': 512, '35': 664, '70': 1312}
    else:
        print('Error')
    
    return ref_dict[str(patients_num)]

# BCP Loss

In [7]:
class DiceLoss(nn.Module): 
    def __init__(self, n_classes): 
        super(DiceLoss, self).__init__() 
        self.n_classes = n_classes
    
    def _one_hot_encoder(self, input_tensor): # torch.nn.functional.one_hot()
        """
        Apply one-hot encoder for input_tensor 
        Parameters: 
            - input_tensor.shape = (batchsize,1, H, W), the target image
        """
        tensor_list = [] 
        for i in range(self.n_classes): 
            temp_prob = input_tensor == i * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim= 1)
        return output_tensor.float() 
    
    def _dice_loss(self, score, target): 
        target = target.float() 
        smooth = 1e-10 
        
        intersection = torch.sum(score * target)
        union = torch.sum(score* score) + torch.sum(target*target)
        dice = ( 2*intersection + smooth) / (union + smooth)
        loss = 1 - dice 
        return loss 
    
    def _dice_mask_loss(self, score, target, mask): 
        target = target.float() 
        mask = mask.float() 
        smooth = 1e-10 

        intersection = torch.sum(score * target * mask)
        union = torch.sum(score * score * mask ) + torch.sum(target * target * mask)
        dice = (2*intersection + smooth) / (union + smooth)
        loss = 1 - dice 
        return loss 

    def forward(self, inputs, target, mask= None, weight= None, softmax= False): 
        if softmax: 
            inputs = torch.softmax(inputs, dim= 1) 
        
        target = self._one_hot_encoder(target)

        # weight 
        if weight is  None: 
            weight = [1] * self.n_classes
        
        assert inputs.size() == target.size(), 'predict and target shape do not match'
        class_wise_dice = [] 
        loss = 0.0 
        if mask is not None: 
            mask = mask.repeat(1, self.n_classes, 1, 1).type(torch.float32)
            for i in range(0, self.n_classes): 
                dice = self._dice_mask_loss(inputs[:, i], target[:, i], mask[:, i])
                class_wise_dice.append( 1.0 - dice.item())
                loss += dice * weight[i]

        else: 
            for i in range(0, self.n_classes): 
                dice = self._dice_loss(inputs[:, i], target[:, i]) 
                class_wise_dice.append(1.0 - dice.item())
                loss += dice * weight[i] 
        
        return loss / self.n_classes
    

    
dice_loss = DiceLoss(n_classes= 4)

In [8]:
def mix_loss(output, img_l, patch_l, mask, l_weight=1.0, u_weight=0.5, unlab=False):
    CE = nn.CrossEntropyLoss(reduction='none')
    img_l, patch_l = img_l.type(torch.int64), patch_l.type(torch.int64)
    output_soft = F.softmax(output, dim=1)
    image_weight, patch_weight = l_weight, u_weight
    if unlab:
        image_weight, patch_weight = u_weight, l_weight
    patch_mask = 1 - mask
    loss_dice = dice_loss(output_soft, img_l.unsqueeze(1), mask.unsqueeze(1)) * image_weight
    loss_dice += dice_loss(output_soft, patch_l.unsqueeze(1), patch_mask.unsqueeze(1)) * patch_weight
    loss_ce = image_weight * (CE(output, img_l) * mask).sum() / (mask.sum() + 1e-16) 
    loss_ce += patch_weight * (CE(output, patch_l) * patch_mask).sum() / (patch_mask.sum() + 1e-16)#loss = loss_ce
    return loss_dice, loss_ce

# MAE Loss

In [9]:
def reconstruction_loss(X_rec, X_orig, mask, lam=0.1):
    """
    Compute full L_REC loss as in the SDCL paper:
    mask: 1 for visible, 0 for masked
    lam: lambda weight for visible region
    """
    loss_masked = ((1 - mask) * (X_rec - X_orig) ** 2).sum()
    loss_visible = (mask * (X_rec - X_orig) ** 2).sum()
    
    total_pixels = X_orig.numel()
    loss = (loss_masked + lam * loss_visible) / total_pixels
    return loss


# utils

In [10]:
def get_ACDC_2DLargestCC(segmentation):
    batch_list = []
    N = segmentation.shape[0]
    for i in range(0, N):
        class_list = []
        for c in range(1, 4):
            temp_seg = segmentation[i] #== c *  torch.ones_like(segmentation[i])
            temp_prob = torch.zeros_like(temp_seg)
            temp_prob[temp_seg == c] = 1
            temp_prob = temp_prob.detach().cpu().numpy()
            labels = label(temp_prob)          
            if labels.max() != 0:
                largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1
                class_list.append(largestCC * c)
            else:
                class_list.append(temp_prob)
        
        n_batch = class_list[0] + class_list[1] + class_list[2]
        batch_list.append(n_batch)

    return torch.Tensor(batch_list).cuda()
    
def get_ACDC_masks(output, nms=0):
    probs = F.softmax(output, dim=1)
    _, probs = torch.max(probs, dim=1)
    if nms == 1:
        probs = get_ACDC_2DLargestCC(probs)      
    return probs

In [11]:
def save_net_opt(net, optimizer, path):
    state = {
        'net': net.state_dict(),
        'optim': optimizer.state_dict()
    }
    torch.save(state, str(path))

def load_net(net, path):
    state = torch.load(str(path))
    net.load_state_dict(state['net']) 

def load_net_opt(net, optimizer, path): 
    state = torch.load(str(path))
    net.load_state_dict(state['net'])
    optimizer.load_state_dict(state['optim'])

In [12]:
def generate_mask(img):
    batch_size, channel, img_x, img_y = img.shape[0], img.shape[1], img.shape[2], img.shape[3]
    loss_mask = torch.ones(batch_size, img_x, img_y).cuda()
    mask = torch.ones(img_x, img_y).cuda()
    patch_x, patch_y = int(img_x*2/3), int(img_y*2/3)
    w = np.random.randint(0, img_x - patch_x)
    h = np.random.randint(0, img_y - patch_y)
    mask[w:w+patch_x, h:h+patch_y] = 0
    loss_mask[:, w:w+patch_x, h:h+patch_y] = 0
    return mask.long(), loss_mask.long()

In [13]:
def sigmoid_rampup(current, rampup_length):
    if rampup_length == 0: 
        return 1.0 
    else:
        current = np.clip(current, 0, rampup_length)
        phase = 1 - (current / rampup_length)
        return float(np.exp(-5 * phase * phase))

In [14]:
# Mean-Teacher compomnent 
def get_current_consistency_weight(epoch, args): 
    return 5 * args.consistency + sigmoid_rampup(epoch, args.consistency_rampup)

def update_model_ema(model, ema_model, alpha): 
    model_state = model.state_dict() 
    model_ema_state = ema_model.state_dict()
    new_dict = {}

    for key in model_state:
        new_dict[key] = alpha * model_ema_state[key] + (1 - alpha) * model_state[key]

    ema_model.load_state_dict(new_dict)

In [15]:
def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    else:
        return 0, 0


def test_single_volume(image, label, model, classes, img_size=[256, 256]):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    prediction = np.zeros_like(label)
    for ind in range(image.shape[0]):
        slice = image[ind, :, :]
        x, y = slice.shape[0], slice.shape[1]
        slice = zoom(slice, (img_size[0] / x, img_size[1] / y), order=0)
        input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
        model.eval()
        with torch.no_grad():
            output = model.forward_segmentation(input)
            if len(output)>1:
                output = output[0]
            out = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze(0)
            out = out.cpu().detach().numpy()
            pred = zoom(out, (x / img_size[0], y / img_size[1]), order=0)
            prediction[ind] = pred
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))
    return metric_list


# random masking

In [16]:
def gen_random_mask(x, patch_size, mask_ratio):
        N = x.shape[0]
        L = (x.shape[2] // patch_size) ** 2
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.randn(N, L, device=x.device)

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # generate the binary mask: 0 is keep 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return mask


def upsample_mask(mask, patch_size, H, W):
    p = int(mask.shape[1] ** 0.5)
    mask = mask.reshape(-1, p, p).unsqueeze(1)  # [B, 1, h, w]
    mask = mask.repeat_interleave(patch_size, 2).repeat_interleave(patch_size, 3)
    return mask  # [B, 1, H, W]

# Unet backbone

In [17]:
class GRN_ChannelFirst(nn.Module):
    def __init__(self, num_channels, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.eps = eps

    def forward(self, x):
        gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)   # L2 norm [B, C, 1, 1]
        nx = gx / (gx.mean(dim=1, keepdim=True) + self.eps) # Normalize across channels
        return self.gamma * (x * nx) + self.beta + x # Residual + learnable modulation


# Build again UNet backbone 
class ConvBlock(nn.Module): 
    """
    Two convolution block with batchnorm and leakyrelu 
    Dont change the output size  
    """
    def __init__(self, in_channel, out_channel, dropout_p, use_grn=True): 
        super(ConvBlock, self).__init__()
        self.use_grn = use_grn
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel), 
            nn.LeakyReLU(), 
            nn.Dropout(dropout_p)
        )
    
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel), 
            nn.LeakyReLU()
        )
        self.grn = GRN_ChannelFirst(out_channel)  # Add GRN layer

    def forward(self, x, apply_grn=False): 
        x = self.conv1(x)
        x = self.conv2(x)
        if apply_grn:
            x = self.grn(x)
        return x

class DownBlock(nn.Module): 
    """
    Downsample follow by ConvBlock
    """
    def __init__(self, in_channel, out_channel, dropout_p): 
        super(DownBlock, self).__init__() 
        
        self.pool = nn.MaxPool2d(kernel_size= 2)
        self.conv = ConvBlock(in_channel, out_channel, dropout_p)


        # self.maxpool_conv = nn.Sequential(
        #     nn.MaxPool2d(kernel_size= 2), 
        #     ConvBlock(in_channel, out_channel, dropout_p)
        # )

    def forward(self, x, apply_grn=False):
        x = self.pool(x)
        x = self.conv(x, apply_grn = apply_grn) 
        return x
    

class UpBlock(nn.Module): # Check if have problem  
    def __init__(self, in_channel1, in_channel2, out_channel, dropout_p): 
        super(UpBlock, self).__init__() 

        self.convx1 = nn.Conv2d(in_channel1, in_channel2, kernel_size= 1) # WRD
        self.up = nn.Upsample(scale_factor= 2, mode= 'bilinear', align_corners= True)
        self.conv = ConvBlock(in_channel2 * 2, out_channel, dropout_p)
         
    def forward(self, x1, x2, apply_grn=False): 
        x1 = self.convx1(x1)
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim= 1)
        return self.conv(x, apply_grn=apply_grn) 

class Encoder(nn.Module): 
    def __init__(self, params): 
        super(Encoder, self).__init__()
        self.params = params
        self.in_chs = self.params['in_chs']
        self.ft_chs = self.params['ft_chs']
        self.n_class = self.params['num_class']
        self.dropout = self.params['dropout']
        assert (len(self.ft_chs) == 5)

        self.conv = ConvBlock(self.in_chs, self.ft_chs[0], dropout_p= self.dropout[0])
        self.down1 = DownBlock(self.ft_chs[0], self.ft_chs[1], self.dropout[1])
        self.down2 = DownBlock(self.ft_chs[1], self.ft_chs[2], self.dropout[2])
        self.down3 = DownBlock(self.ft_chs[2], self.ft_chs[3], self.dropout[3])
        self.down4 = DownBlock(self.ft_chs[3], self.ft_chs[4], self.dropout[4])

    def forward(self, x, apply_grn=False): 
        x0 = self.conv(x, apply_grn=apply_grn) 
        x1 = self.down1(x0, apply_grn=apply_grn) 
        x2 = self.down2(x1, apply_grn=apply_grn) 
        x3 = self.down3(x2, apply_grn=apply_grn)
        x4 = self.down4(x3, apply_grn=apply_grn)
        return [x0, x1, x2, x3, x4]


class Decoder(nn.Module): 
    def __init__(self, params): 
        super(Decoder, self).__init__() 
        self.params = params 
        self.in_chs = self.params['in_chs']
        self.ft_chs = self.params['ft_chs']
        self.n_class = self.params['num_class']
        assert (len(self.ft_chs) == 5)

        self.up1 = UpBlock(self.ft_chs[4], self.ft_chs[3], self.ft_chs[3], dropout_p= 0.0)
        self.up2 = UpBlock(self.ft_chs[3], self.ft_chs[2], self.ft_chs[2], dropout_p= 0.0)
        self.up3 = UpBlock(self.ft_chs[2], self.ft_chs[1], self.ft_chs[1], dropout_p= 0.0)
        self.up4 = UpBlock(self.ft_chs[1], self.ft_chs[0], self.ft_chs[0], dropout_p= 0.0)
        self.out_conv = nn.Conv2d(self.ft_chs[0], self.n_class, kernel_size= 3, padding= 1)


    def forward(self, feature, apply_grn = False): 
        x0 = feature[0] 
        x1 = feature[1] 
        x2 = feature[2] 
        x3 = feature[3]
        x4 = feature[4] 

        x = self.up1(x4, x3, apply_grn) 
        x = self.up2(x, x2, apply_grn)
        x = self.up3(x, x1, apply_grn)
        x_last = self.up4(x, x0, apply_grn)
        output = self.out_conv(x_last)
        return output, x_last
    
class UNet2d(nn.Module): 
    def __init__(self, in_chs, class_num=None, recon=False): 
        super(UNet2d, self).__init__() 
        self.recon = recon
        self.params = {
            'in_chs': in_chs, 
            'ft_chs': [16, 32, 64, 128, 256], 
            'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], 
            'num_class': class_num if not recon else in_chs
        }

        self.encoder = Encoder(self.params) 
        self.decoder = Decoder(self.params)

        # ⬇️ final projection layer for reconstruction
        if recon:
            self.out_conv = nn.Conv2d(self.params['ft_chs'][0], in_chs, kernel_size=3, padding=1)
        else:
            self.out_conv = None

    def forward(self, x): 
        features = self.encoder(x) 
        x_out, x_last = self.decoder(features)

        if self.recon:
            return self.out_conv(x_last)  # Ensure output shape = input shape
        else:
            return x_out
        

# MBCP 

In [18]:
class MBCP(nn.Module):
    def __init__(self, encoder_type='unet', in_chs=1, class_num=4, use_recon = False):
        super(MBCP, self).__init__()

        self.shared_params = {
            'in_chs': in_chs,
            'ft_chs': [16, 32, 64, 128, 256],
            'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
            'num_class': class_num  # For seg decoder
        }

        # Shared encoder
        self.encoder = Encoder(self.shared_params)

        # Segmentation decoder
        self.decoder_seg = Decoder(self.shared_params)

        # Reconstruction decoder (output RGB)
        mae_params = self.shared_params.copy()
        mae_params['num_class'] = in_chs
        self.decoder_recon = Decoder(mae_params)

        
    def forward_segmentation(self, x_bcp):
        # shared encoder
        feat_bcp = self.encoder(x_bcp)
        seg_out, _ = self.decoder_seg(feat_bcp)  # Segmentation output [B, class_num, H, W] 
        return seg_out
    
    def forward_reconstruction(self, x_masked):
        feat_masked = self.encoder(x_masked)
        rec_out, _ = self.decoder_recon(feat_masked) # Reconstruction output [B, in_chns, H, W]
        return rec_out  

    def forward(self, x):  
        return self.forward_segmentation(x)

# engine

In [19]:
def MBCP_net(role, in_chs=1, class_num=4, backbone='unet', ema=False, use_recon=False):
    if role == 'mbcp':
        # encoder shared, 2 decoder forn segmentation and reconstruction
        model = MBCP(encoder_type=backbone, in_chs=in_chs, class_num=class_num, use_recon=use_recon).cuda()
    
    
    elif role == 'teacher':
        model = MBCP(encoder_type = backbone, in_chs=in_chs, class_num=class_num, use_recon=use_recon).cuda() # ignore decoder mae
        if ema:
            for param in model.parameters():
                param.detach_() 
    
    return model


# pre train

In [20]:
# Configuration
def pre_train(args, snapshot_path):
    base_lr = args.base_lr
    num_classes = args.num_classes
    max_iterations = args.pretrain_iterations
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 
    pre_trained_model = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
    labeled_sub_bs, unlabeled_sub_bs = int(args.labeled_bs / 2), int((args.batch_size - args.labeled_bs)/2)

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    # load dataset 
    db_train = ACDCDataset(base_dir= args.root_dir, split= 'train_lab',
                        transform= transforms.Compose([RandomGenerator(args.img_size)]))
    
    db_val = ACDCDataset(base_dir= args.root_dir, split= 'val')

    total_slices = len(db_train)
    labeled_slices = patients_to_slices(args.root_dir, args.label_num)
    print(f'Total slice is {total_slices}, Labeled slice is {labeled_slices}')

    # Create batch_sampler 
    labeled_idxs = list(range(0, labeled_slices))
    unlabeled_idxs = list(range(labeled_slices, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, args.batch_size, args.batch_size - args.labeled_bs)

    # Create dataloader 
    trainloader = DataLoader(db_train, batch_sampler= batch_sampler, num_workers= 4, pin_memory= True, worker_init_fn= worker_init_fn)
    valloader = DataLoader(db_val, batch_size= 1, shuffle= False, num_workers=1)

    # Define model 
    model = MBCP_net(role='mbcp', in_chs=1, class_num=num_classes, use_recon=False)
    optimizer = torch.optim.AdamW(model.parameters(), lr= base_lr, weight_decay= 0.0001)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info('Start pre-training')
    logging.info(f'{len(trainloader)} iterations per epoch')

    # training process
    model.train() 
    iter_num = 0 
    max_epoch = max_iterations // len(trainloader) + 1 
    best_performance = 0.0 
    best_hd = 100.0
    iterator = tqdm(range(max_epoch), ncols= 70)
    
    for _ in iterator: 
        for _, sampled_batch in enumerate(trainloader): 
            volume_batch, label_batch = sampled_batch
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 

            img_a, img_b = volume_batch[: labeled_sub_bs], volume_batch[labeled_sub_bs : args.labeled_bs]
            lab_a, lab_b = label_batch[: labeled_sub_bs], label_batch[labeled_sub_bs : args.labeled_bs]
            img_mask, loss_mask = generate_mask(img_a)
            gt_mixl = lab_a * img_mask + lab_b * ( 1- img_mask)

            #-- original 
            net_input = img_a * img_mask + img_b * ( 1 - img_mask)  
            out_mixl = model.forward_segmentation(net_input)
            loss_dice, loss_ce = mix_loss(out_mixl, lab_a, lab_b, loss_mask,u_weight= 1.0, unlab= True )
            loss = (loss_dice + loss_ce )/2 

            optimizer.zero_grad() 
            loss.backward() 
            optimizer.step() 

            iter_num += 1 

            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/mix_dice', loss_dice,iter_num )
            writer.add_scalar('info/mix_ce', loss_ce, iter_num)

            logging.info('iteration %d: loss %f, mix_dice: %f, mix_ce: %f'%(iter_num, loss, loss_dice, loss_ce))
            if iter_num % 20 == 0: 
                image = net_input[1, 0:1, :, :]
                writer.add_image('pre_train/Mixed_Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(out_mixl, dim=1), dim=1, keepdim= True)
                writer.add_image('pre_train/Mixed_Prediction', outputs[1, ...]*50, iter_num)
                labs = gt_mixl[1, ...].unsqueeze(0) * 50 
                writer.add_image('pre_train/Mixed_GroundTruth', labs, iter_num)
            
            # Evaluate after 200 epoch ! 
            if iter_num > 0 and iter_num % 200 == 0: 
                model.eval() 
                metric_list = 0.0 
                for _, sampled_batch in enumerate(valloader):
                    image_batch, label_batch = sampled_batch
                    metric_i = test_single_volume(image_batch, label_batch, model, classes= num_classes)
                    metric_list += np.array(metric_i)
                
                metric_list = metric_list / len(db_val)
                
    
                for class_i in range(num_classes - 1 ): 
                    writer.add_scalar('info/val_{}_dice'.format(class_i + 1), metric_list[class_i, 0], iter_num)
                    writer.add_scalar('infor/val_{}_hd'.format(class_i + 1), metric_list[class_i, 1], iter_num)
                
                performance = np.mean(metric_list, axis=0)[0]
                writer.add_scalar('info/val_mean_dice', performance, iter_num)
            
                if performance > best_performance: 
                    best_performance = performance
                    save_model_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance,4)))
                    save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                    save_net_opt(model, optimizer, save_model_path)
                    save_net_opt(model, optimizer, save_best_path)
                
                logging.info('iteration %d : mean dice : %f'%(iter_num, performance))
                model.train() 
            
            if iter_num >= max_iterations: 
                break 
        
        if iter_num >= max_iterations: 
            iterator.close() 
            break 
    
    writer.close()

# self train (BCP + MAE)

In [21]:
def self_train(args ,pre_snapshot_path, snapshot_path):
    base_lr = args.base_lr
    num_classes = args.num_classes
    max_iterations = args.selftrain_iterations
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    pre_trained_model = os.path.join(pre_snapshot_path,'{}_best_model.pth'.format(args.model))
    labeled_sub_bs, unlabeled_sub_bs = int(args.labeled_bs/2), int((args.batch_size-args.labeled_bs) / 2)

    # model
    model = MBCP_net(role='mbcp', in_chs=1, class_num=num_classes, use_recon = True).cuda()
    ema_model = MBCP_net(role='teacher', in_chs=1, class_num=num_classes, ema=True, use_recon = False).cuda()

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    db_train = ACDCDataset(base_dir=args.root_dir,
                            split="train_lab",
                            transform=transforms.Compose([RandomGenerator(args.img_size)]))
    
    db_mae = ACDCDataset(base_dir=args.root_dir, split='reconstruct',
                         transform=transforms.Compose([RandomGenerator(args.img_size)]))


    db_val = ACDCDataset(base_dir=args.root_dir, split="val")


    total_slices = len(db_train)
    labeled_slice = patients_to_slices(args.root_dir,args.label_num)
    print("Total slices is: {}, labeled slices is:{}".format(total_slices, labeled_slice))
    labeled_idxs = list(range(0, labeled_slice))
    unlabeled_idxs = list(range(labeled_slice, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, args.batch_size, args.batch_size-args.labeled_bs)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    mae_batch, _ = next(iter(DataLoader(db_mae, batch_size=args.batch_size, shuffle=True)))

    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.0001)
    load_net(ema_model, pre_trained_model)
    load_net_opt(model, optimizer, pre_trained_model) #
    logging.info("Loaded from {}".format(pre_trained_model))

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("Start self_training")
    logging.info("{} iterations per epoch".format(len(trainloader)))

    model.train()
    ema_model.train()

    ce_loss = nn.CrossEntropyLoss()

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    best_hd = 100
    iterator = tqdm(range(max_epoch), ncols=70)
    for _ in iterator:
        for _, sampled_batch in enumerate(trainloader):
            
            volume_batch, label_batch = sampled_batch
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            
            # mae input
            mask = gen_random_mask(mae_batch, patch_size=args.patch_size, mask_ratio=args.mask_ratio)
            mask_up = upsample_mask(mask, patch_size=args.patch_size, H = mae_batch.shape[2], W = mae_batch.shape[3])
            masked_img = mae_batch * (1 - mask_up)

            img_a, img_b = volume_batch[:labeled_sub_bs], volume_batch[labeled_sub_bs:args.labeled_bs]
            uimg_a, uimg_b = volume_batch[args.labeled_bs:args.labeled_bs + unlabeled_sub_bs], volume_batch[args.labeled_bs + unlabeled_sub_bs:]
            ulab_a, ulab_b = label_batch[args.labeled_bs:args.labeled_bs + unlabeled_sub_bs], label_batch[args.labeled_bs + unlabeled_sub_bs:]
            lab_a, lab_b = label_batch[:labeled_sub_bs], label_batch[labeled_sub_bs:args.labeled_bs]
            with torch.no_grad():

                
                pre_a = ema_model(uimg_a) # pseudo label a 
                pre_b = ema_model(uimg_b) # pseudo label b 
                plab_a = get_ACDC_masks(pre_a, nms=1)
                plab_b = get_ACDC_masks(pre_b, nms=1)
                img_mask, loss_mask = generate_mask(img_a)
                unl_label = ulab_a * img_mask + lab_a * (1 - img_mask)
                l_label = lab_b * img_mask + ulab_b * (1 - img_mask)
            consistency_weight = get_current_consistency_weight(iter_num//150,args)

            net_input_unl = uimg_a * img_mask + img_a * (1 - img_mask)
            net_input_l = img_b * img_mask + uimg_b * (1 - img_mask)
            out_unl = model.forward_segmentation(net_input_unl)
            out_l = model.forward_reconstruction(net_input_l)

            # forward mae
            rec_out = model.forward_reconstruction(masked_img.cuda())


            rec_loss = reconstruction_loss(rec_out, mae_batch, mask_up)


            unl_dice, unl_ce = mix_loss(out_unl, plab_a, lab_a, loss_mask, u_weight=args.u_weight, unlab=True)
            l_dice, l_ce = mix_loss(out_l, lab_b, plab_b, loss_mask, u_weight=args.u_weight)


            loss_ce = unl_ce + l_ce 
            loss_dice = unl_dice + l_dice

            total_loss = 1 * (loss_dice + loss_ce) + 1 * rec_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            iter_num += 1
            update_model_ema(model, ema_model, 0.99)

            writer.add_scalar('info/total_loss', total_loss, iter_num)
            writer.add_scalar('info/mix_dice', loss_dice, iter_num)
            writer.add_scalar('info/mix_ce', loss_ce, iter_num)
            writer.add_scalar('info/consistency_weight', consistency_weight, iter_num)     

            logging.info('iteration %d: loss: %f, mix_dice: %f, mix_ce: %f'%(iter_num, total_loss, loss_dice, loss_ce))
                
            if iter_num % 20 == 0:
                image = net_input_unl[1, 0:1, :, :]
                writer.add_image('train/Un_Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(out_unl, dim=1), dim=1, keepdim=True)
                writer.add_image('train/Un_Prediction', outputs[1, ...] * 50, iter_num)
                labs = unl_label[1, ...].unsqueeze(0) * 50
                writer.add_image('train/Un_GroundTruth', labs, iter_num)

                image_l = net_input_l[1, 0:1, :, :]
                writer.add_image('train/L_Image', image_l, iter_num)
                outputs_l = torch.argmax(torch.softmax(out_l, dim=1), dim=1, keepdim=True)
                writer.add_image('train/L_Prediction', outputs_l[1, ...] * 50, iter_num)
                labs_l = l_label[1, ...].unsqueeze(0) * 50
                writer.add_image('train/L_GroundTruth', labs_l, iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                metric_list = 0.0
                for _, sampled_batch in enumerate(valloader):
                    metric_i = test_single_volume(sampled_batch["image"], sampled_batch["label"], model, classes= num_classes)
                    metric_list += np.array(metric_i)
                metric_list = metric_list / len(db_val)
                print(f'Metric list: {metric_list}') 
                for class_i in range(num_classes-1):
                    writer.add_scalar('info/val_{}_dice'.format(class_i+1), metric_list[class_i, 0], iter_num)
                    writer.add_scalar('info/val_{}_hd95'.format(class_i+1), metric_list[class_i, 1], iter_num)

                performance = np.mean(metric_list, axis=0)[0]
                writer.add_scalar('info/val_mean_dice', performance, iter_num)

                if performance > best_performance:
                    best_performance = performance
                    save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance, 4)))
                    save_best_path = os.path.join(snapshot_path,'{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best_path)

                logging.info('iteration %d : mean_dice : %f' % (iter_num, performance))
                model.train()

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()


In [22]:
# Trainig process 
if args.deterministic: 
    cudnn.benchmark = False 
    cudnn.deterministic = True 
    random.seed(args.seed) 
    np.random.seed(args.seed) 
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

# Path 
pre_snapshot_path = "./model/BCP/ACDC_{}_{}_labeled/pretrain".format(args.exp, args.label_num)
self_snapshot_path = "./model/BCP/ACDC_{}_{}_labeled/selftrain".format(args.exp, args.label_num)

print(f'Pretrain log path: {pre_snapshot_path + "/log.txt"}')
print(f'Self-train log path: {self_snapshot_path + "/log.txt"}')

for snapshot_path in [pre_snapshot_path, self_snapshot_path]: 
    if not os.path.exists(snapshot_path): 
        os.makedirs(snapshot_path, exist_ok= True)
#Pre_train
logging.basicConfig(filename=pre_snapshot_path+"/log.txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
pre_train(args, pre_snapshot_path)

#Self_train
logging.basicConfig(filename=self_snapshot_path+"/log.txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
self_train(args, pre_snapshot_path, self_snapshot_path)

Pretrain log path: ./model/BCP/ACDC_MBCP_7_labeled/pretrain/log.txt
Self-train log path: ./model/BCP/ACDC_MBCP_7_labeled/selftrain/log.txt
Mode: train_lab: 136 samples in total
Mode: val: 20 samples in total
Total slice is 1360, Labeled slice is 136
Start pre-training
11 iterations per epoch


  0%|                                           | 0/5 [00:00<?, ?it/s]

iteration 1: loss 2.307791, mix_dice: 1.661705, mix_ce: 2.953876
iteration 2: loss 2.133703, mix_dice: 1.631261, mix_ce: 2.636145
iteration 3: loss 2.011671, mix_dice: 1.605264, mix_ce: 2.418077
iteration 4: loss 1.911551, mix_dice: 1.571022, mix_ce: 2.252081
iteration 5: loss 1.813652, mix_dice: 1.529119, mix_ce: 2.098186
iteration 6: loss 1.758866, mix_dice: 1.536966, mix_ce: 1.980766
iteration 7: loss 1.676482, mix_dice: 1.485111, mix_ce: 1.867853
iteration 8: loss 1.616066, mix_dice: 1.493103, mix_ce: 1.739028
iteration 9: loss 1.539205, mix_dice: 1.463070, mix_ce: 1.615341
iteration 10: loss 1.513194, mix_dice: 1.513221, mix_ce: 1.513168
iteration 11: loss 1.416934, mix_dice: 1.404958, mix_ce: 1.428910


 20%|███████                            | 1/5 [00:01<00:07,  1.79s/it]

iteration 12: loss 1.369832, mix_dice: 1.420684, mix_ce: 1.318980
iteration 13: loss 1.312752, mix_dice: 1.445195, mix_ce: 1.180309
iteration 14: loss 1.262157, mix_dice: 1.406788, mix_ce: 1.117525
iteration 15: loss 1.234733, mix_dice: 1.405270, mix_ce: 1.064196
iteration 16: loss 1.172330, mix_dice: 1.400735, mix_ce: 0.943925
iteration 17: loss 1.126444, mix_dice: 1.357634, mix_ce: 0.895254
iteration 18: loss 1.086958, mix_dice: 1.335052, mix_ce: 0.838865
iteration 19: loss 1.048236, mix_dice: 1.328603, mix_ce: 0.767869
iteration 20: loss 1.045889, mix_dice: 1.365126, mix_ce: 0.726652
iteration 21: loss 0.996591, mix_dice: 1.338143, mix_ce: 0.655039
iteration 22: loss 1.006549, mix_dice: 1.369876, mix_ce: 0.643222


 40%|██████████████                     | 2/5 [00:03<00:04,  1.54s/it]

iteration 23: loss 0.999990, mix_dice: 1.338724, mix_ce: 0.661256
iteration 24: loss 0.978765, mix_dice: 1.353532, mix_ce: 0.603999
iteration 25: loss 0.932210, mix_dice: 1.301591, mix_ce: 0.562829
iteration 26: loss 0.929617, mix_dice: 1.353225, mix_ce: 0.506009
iteration 27: loss 0.946498, mix_dice: 1.411516, mix_ce: 0.481480
iteration 28: loss 0.944127, mix_dice: 1.373029, mix_ce: 0.515224
iteration 29: loss 0.847859, mix_dice: 1.255918, mix_ce: 0.439801
iteration 30: loss 0.846381, mix_dice: 1.265321, mix_ce: 0.427440
iteration 31: loss 0.866975, mix_dice: 1.287335, mix_ce: 0.446616
iteration 32: loss 0.894379, mix_dice: 1.293718, mix_ce: 0.495039
iteration 33: loss 0.970430, mix_dice: 1.369877, mix_ce: 0.570984


 60%|█████████████████████              | 3/5 [00:04<00:02,  1.50s/it]

iteration 34: loss 0.898077, mix_dice: 1.328545, mix_ce: 0.467609
iteration 35: loss 0.846893, mix_dice: 1.282130, mix_ce: 0.411655
iteration 36: loss 0.801106, mix_dice: 1.216877, mix_ce: 0.385334
iteration 37: loss 0.898235, mix_dice: 1.358977, mix_ce: 0.437493
iteration 38: loss 0.872980, mix_dice: 1.358505, mix_ce: 0.387455
iteration 39: loss 0.905890, mix_dice: 1.342122, mix_ce: 0.469658
iteration 40: loss 0.896961, mix_dice: 1.345330, mix_ce: 0.448591
iteration 41: loss 0.809853, mix_dice: 1.225058, mix_ce: 0.394649
iteration 42: loss 0.851451, mix_dice: 1.294019, mix_ce: 0.408883
iteration 43: loss 0.940976, mix_dice: 1.464186, mix_ce: 0.417766
iteration 44: loss 0.784245, mix_dice: 1.237287, mix_ce: 0.331203


 80%|████████████████████████████       | 4/5 [00:06<00:01,  1.48s/it]

iteration 45: loss 0.902968, mix_dice: 1.388827, mix_ce: 0.417108
iteration 46: loss 0.916643, mix_dice: 1.379431, mix_ce: 0.453854
iteration 47: loss 0.780852, mix_dice: 1.219011, mix_ce: 0.342694
iteration 48: loss 0.841862, mix_dice: 1.359953, mix_ce: 0.323772
iteration 49: loss 0.753177, mix_dice: 1.238876, mix_ce: 0.267478
iteration 50: loss 0.881429, mix_dice: 1.320369, mix_ce: 0.442489


 80%|████████████████████████████       | 4/5 [00:07<00:01,  1.79s/it]

Mode: train_lab: 136 samples in total
Mode: reconstruct: 1902 samples in total
Mode: val: 20 samples in total
Total slices is: 1360, labeled slices is:136





Loaded from ./model/BCP/ACDC_MBCP_7_labeled/pretrain/unet_best_model.pth
Loaded from ./model/BCP/ACDC_MBCP_7_labeled/pretrain/unet_best_model.pth
Start self_training
Start self_training
11 iterations per epoch
11 iterations per epoch


  return torch.Tensor(batch_list).cuda()
  0%|                                           | 0/1 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 192.00 MiB. GPU 0 has a total capacity of 3.80 GiB of which 175.69 MiB is free. Including non-PyTorch memory, this process has 3.58 GiB memory in use. Of the allocated memory 3.44 GiB is allocated by PyTorch, and 37.51 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)