In [1]:
import os

import monai
import nibabel as nib
import numpy as np
import torch
from monai import transforms
from nnunet_mednext import MedNeXt, create_mednext_v1

from dataset.MMWHS import get_datasets, get_datasets_noPad, get_datasets_Aug
from utils import reload_ckpt_bis, post_trans, decollate_batch, inference


In [2]:
# Validation preprocessing: MONAI NormalizeIntensityd on the "image" key, computed per channel
# and restricted to non-zero voxels (nonzero=True, channel_wise=True).
val_transform = transforms.Compose([
    transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
])

# Fold 0 of the augmented split. NOTE(review): the first positional argument (1234) is presumably
# the split seed - confirm against dataset.MMWHS.get_datasets_Aug.
full_train_dataset, l_val_dataset, bench_dataset = get_datasets_Aug(
    1234,
    fold_number=0,
    val_transforms=val_transform,
)

# One case per batch everywhere; only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(
    full_train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
    drop_last=False,
)
val_loader = torch.utils.data.DataLoader(
    l_val_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    num_workers=0,
)
bench_loader = torch.utils.data.DataLoader(bench_dataset, batch_size=1, num_workers=0)


In [3]:
# MedNeXt v1, size 'S', taking 2 input channels and producing 8 output classes.
model_1 = create_mednext_v1(num_input_channels=2, model_id='S', num_classes=8)

# NOTE(review): absolute, machine-specific checkpoint path - consider making it configurable.
checkpoint = "/home/fanxx/fxx/Multi-modal-Segmentation/MMWHS_pre/Multi_modal/MedNeXt_S/runs/logs_base/model_noPad_2/model_lower_loss.pth.tar"
reload_ckpt_bis(checkpoint, model_1, device='cuda:0')

# Move the restored model to the GPU; as the cell's last expression this also displays its repr.
model_1.cuda()


=> loading checkpoint /home/fanxx/fxx/Multi-modal-Segmentation/MMWHS_pre/Multi_modal/MedNeXt_S/runs/logs_base/model_noPad_2/model_lower_loss.pth.tar


MedNeXt(
  (stem): Conv3d(2, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (enc_block_0): Sequential(
    (0): MedNeXtBlock(
      (conv1): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=32)
      (norm): GroupNorm(32, 32, eps=1e-05, affine=True)
      (conv2): Conv3d(32, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      (act): GELU(approximate='none')
      (conv3): Conv3d(64, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    )
    (1): MedNeXtBlock(
      (conv1): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=32)
      (norm): GroupNorm(32, 32, eps=1e-05, affine=True)
      (conv2): Conv3d(32, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      (act): GELU(approximate='none')
      (conv3): Conv3d(64, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    )
  )
  (down_0): MedNeXtDownBlock(
    (conv1): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), groups=32)
    (norm): GroupNorm(32, 32, ep

In [4]:
def meandice(pred, label, num_class):
    """Mean Dice coefficient over the foreground classes ``1 .. num_class - 1``.

    Args:
        pred: integer label-map tensor of predicted classes; its first dimension is
            treated as the batch dimension.
        label: integer label-map tensor of ground-truth classes, same layout as ``pred``.
        num_class: total number of classes including the background class 0.

    Returns:
        Scalar tensor: the average of the per-class Dice scores. Each score is pooled over
        every element of the inputs (the sum is global, not per sample), and smoothed with
        1e-6 so a class absent from both tensors scores 1.
    """
    eps = 1e-6

    def _class_dice(cls):
        # 0/1 integer masks for this class, flattened per leading-dimension entry.
        p = ((pred == cls) * 1).contiguous().view(pred.shape[0], -1)
        g = ((label == cls) * 1).contiguous().view(label.shape[0], -1)
        overlap = (p * g).sum()
        return (2. * overlap + eps) / (p.sum() + g.sum() + eps)

    # Python's sum() accumulates left to right starting from 0.
    total = sum(_class_dice(cls) for cls in range(1, num_class))
    # Background (class 0) is excluded, hence num_class - 1 terms.
    return total / (num_class - 1)

In [7]:

# Evaluate model_1 on the validation split. For every case this records foreground Dice
# (meandice), MONAI mean IoU and 95th-percentile Hausdorff distance (background excluded),
# and writes inputs, per-channel predictions / ground truth and label maps to ./output/.

# Names used in the saved per-channel file names (list index == channel index).
# 'backgroud' is misspelled but is part of the existing output file names, so it is kept.
labels = ['backgroud', 'CT-A', 'CT-B', 'CT-C', 'CT-D', 'CT-E', 'CT-F', 'CT-G']

# nib.save does not create missing directories.
os.makedirs("./output", exist_ok=True)

# Switch to inference mode so any train-only behaviour (dropout, batch-statistics norms)
# is disabled; the loop previously ran the model in whatever mode it was left in.
model_1.eval()


def _save_nifti(volume, case_id, suffix):
    """Write a tensor to ./output/{case_id}_{suffix}.nii.gz with an identity affine."""
    array = volume.detach().cpu().numpy()
    nib.save(nib.Nifti1Image(array, np.eye(4)), filename=f"./output/{case_id}_{suffix}.nii.gz")


mean_dice = []
mean_hd95 = []
mean_MIOU = []
with torch.no_grad():
    for batch in val_loader:
        patient_id = batch["patient_id"]
        case_id = patient_id[0]  # val_loader uses batch_size=1, so only entry 0 is used
        val_inputs, val_labels = batch["image"].cuda(), batch["label"].cuda()

        # Raw network output; the class dimension is dim 1 (assumed (B, C, D, H, W)).
        val_outputs = model_1(val_inputs)
        num_classes = val_outputs.shape[1]

        # Per-case outputs after utils.post_trans; thresholded at 0.5 when saved below.
        val_outputs_1 = [post_trans(out) for out in decollate_batch(val_outputs)]

        # Hard label maps (predicted / ground-truth class per voxel), computed once and reused.
        pred_labels = torch.argmax(torch.softmax(val_outputs, dim=1), dim=1)
        gt_labels = torch.argmax(val_labels.int(), dim=1)

        # One-hot prediction moved to channel-first layout. num_classes must be explicit:
        # otherwise one_hot infers the channel count from the largest predicted label, so a
        # case where the highest class is never predicted yields fewer channels than
        # val_labels and the MONAI metrics below fail with a shape mismatch.
        val_outputs_2 = torch.nn.functional.one_hot(
            pred_labels, num_classes=num_classes
        ).permute(0, 4, 1, 2, 3)

        # Input channel 0 is saved as CT and channel 1 as MR.
        _save_nifti(val_inputs[0, 0].float(), case_id, "ct")
        _save_nifti(val_inputs[0, 1].float(), case_id, "mr")

        mean_dice.append(meandice(pred_labels, gt_labels, num_classes).item())
        mean_hd95.append(
            monai.metrics.HausdorffDistanceMetric(include_background=False, percentile=95)(
                val_outputs_2, val_labels
            ).mean().item()
        )
        mean_MIOU.append(
            monai.metrics.MeanIoU(include_background=False)(val_outputs_2, val_labels).mean().item()
        )

        # NOTE(review): only channels 0-6 are written, so labels[7] ('CT-G') is never used -
        # confirm this matches the channel count produced by post_trans / val_labels.
        for ch in range(7):
            _save_nifti((val_outputs_1[0][ch] > 0.5).int(), case_id, labels[ch])
        for ch in range(7):
            _save_nifti(val_labels[0, ch].int(), case_id, f"{labels[ch]}_gt")

        # Full label maps for the case.
        _save_nifti(pred_labels[0].int(), case_id, "pred")
        _save_nifti(gt_labels[0].int(), case_id, "gt")

print(f"dice: {mean_dice} mean is : {np.mean(mean_dice) *100}")
print(f"MIOU: {mean_MIOU} mean is : {np.mean(mean_MIOU) *100}")
print(f"hd95: {mean_hd95} mean is : {np.mean(mean_hd95)}")


dice: [0.8348702192306519, 0.8199483752250671] mean is : 82.74092972278595
MIOU: [0.7252334952354431, 0.7044278979301453] mean is : 71.48306965827942
hd95: [8.87491512298584, 11.322295188903809] mean is : 10.098605155944824


In [7]:
# Shape of the hard predicted label map (class dimension reduced away) for the last case.
torch.softmax(val_outputs, dim=1).argmax(dim=1).shape

torch.Size([1, 128, 128, 128])

In [33]:
# Shape of the post-processed output for the first (only) case of the last batch.
val_outputs_1[0].size()

torch.Size([7, 128, 128, 128])

In [41]:
# Peak post-processed value in each of the first 7 channels of the last validation case.
for channel in range(7):
    channel_volume = val_outputs_1[0][channel]
    print(channel_volume.max())

metatensor(0.9492, device='cuda:0')
metatensor(0.8301, device='cuda:0')
metatensor(0.9424, device='cuda:0')
metatensor(0.8032, device='cuda:0')
metatensor(0.9282, device='cuda:0')
metatensor(0.9028, device='cuda:0')
metatensor(0.9263, device='cuda:0')


In [42]:
# Paired CT / MR image and label files for training case 1001, in the dict layout
# consumed by MONAI dictionary transforms ("image" / "label" keys).
temp_dict = dict(
    image=[
        "/home/fanxx/fxx/sdc/luoluo/MMWHS/MMWHS/ct_train/ct_train_1001_image.nii.gz",
        "/home/fanxx/fxx/sdc/luoluo/MMWHS/MMWHS/mr_train/mr_train_1001_image.nii.gz",
    ],
    label=[
        "/home/fanxx/fxx/sdc/luoluo/MMWHS/MMWHS/ct_train/ct_train_1001_label.nii.gz",
        "/home/fanxx/fxx/sdc/luoluo/MMWHS/MMWHS/mr_train/mr_train_1001_label.nii.gz",
    ],
)

In [44]:
from monai import transforms

In [50]:
# Crop size; only referenced by the commented-out crop steps below, so currently unused.
roi = (128, 128, 128)

# Loading-only pipeline: reads the "image" and "label" files and nothing else.
# The augmentation steps are disabled and kept below, commented out, for reference.
train_transform = transforms.Compose([
    transforms.LoadImaged(keys=["image", "label"]),
    # transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
    # transforms.CropForegroundd(keys=["image", "label"], source_key="image",
    #                            k_divisible=[roi[0], roi[1], roi[2]]),
    # transforms.RandSpatialCropd(keys=["image", "label"],
    #                             roi_size=[roi[0], roi[1], roi[2]], random_size=False),
    # transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    # transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    # transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    # transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    # transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
    # transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
])

