In [None]:
# Show the GPU the Colab runtime provides (driver / CUDA version, memory usage).
!nvidia-smi

Mon Nov 29 23:16:27 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Mount Google Drive and work from the project directory on it.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/new_paper/ACDC
# Execute the helper notebooks in this kernel's namespace.
# NOTE(review): names used below but not defined in this notebook (AttU_Net,
# DropBlock2D, ACDCLoader, CrossEntropy, Nadam, dice_rv/dice_myo/dice_lv,
# NUM_CLASS, center_crop, min_max_preprocess) presumably come from these — confirm.
%run otherUnets.ipynb
%run preprocessing_2D.ipynb

In [None]:
!pip install pytorch-lightning
import pytorch_lightning as pl
from IPython.display import clear_output
import nibabel as nib
import csv
import os
import glob
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from fastprogress import master_bar, progress_bar
from torchvision import transforms
import torch.optim as optim
from itertools import product
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class HistoryLogger(pl.callbacks.Callback):
    """Lightning callback that appends one CSV row of metrics per validation epoch.

    Args:
        dir: path of the CSV file. Rows are appended; the header row is written
            only when the file does not exist yet.
    """

    def __init__(self, dir = "history_acdcAttUnet_CE.csv"):
        self.dir = dir

    def on_validation_epoch_end(self, trainer, pl_module):
        """Append the epoch index plus train/val metrics to ``self.dir``.

        Does nothing until ``"loss_epoch"`` is present in
        ``trainer.callback_metrics``; once it is, every tracked key must exist
        (a missing one raises ``KeyError``).
        """
        callback_metrics = trainer.callback_metrics
        if "loss_epoch" not in callback_metrics.keys():
            return
        # NOTE(review): depending on Lightning's hook order, the *_epoch training
        # values may still belong to the previous epoch when validation ends — confirm.
        tracked = (
            "loss_epoch", "train_diceRV_epoch", "train_diceMYO_epoch", "train_diceLV_epoch",
            "val_loss", "val_diceRV", "val_diceMYO", "val_diceLV",
        )
        row = {"epoch": trainer.current_epoch}
        row.update((name, callback_metrics[name].item()) for name in tracked)
        write_header = not os.path.isfile(self.dir)
        with open(self.dir, 'a', newline='') as csv_handle:
            writer = csv.DictWriter(csv_handle, fieldnames=list(row))
            if write_header:
                writer.writeheader()
            writer.writerow(row)
def setDropProb(model, prob=0.01):
    """Set ``drop_prob`` to ``prob`` on every DropBlock2D submodule of ``model``."""
    drop_blocks = (module for module in model.modules() if isinstance(module, DropBlock2D))
    for block in drop_blocks:
        block.drop_prob = prob
clear_output()
############## turn off Debug APIs for Final Training############
# Anomaly detection is a global switch and is turned off here. The autograd
# profilers (torch.autograd.profiler.profile / emit_nvtx) only record inside a
# ``with`` block; merely constructing them does nothing, so they need no call.
# The trailing ';' suppresses the object repr in the cell output.
torch.autograd.set_detect_anomaly(False);

<torch.autograd.profiler.emit_nvtx at 0x7f2ad1d64650>

In [None]:
def _load_image_mask(npz_path):
    """Return the ``image`` and ``mask`` arrays stored in the ``.npz`` archive at ``npz_path``.

    The ``with`` block closes the archive as soon as both arrays are read, rather
    than leaving the file handle open until the NpzFile is garbage-collected.
    """
    with np.load(npz_path) as archive:
        return archive["image"], archive["mask"]


x_train, y_train = _load_image_mask("./dataACDCA/ACDC_train_aug160.npz")
x_val, y_val = _load_image_mask("./dataACDCA/ACDC_val160.npz")
x_test, y_test = _load_image_mask("./dataACDCA/ACDC_test160.npz")

In [None]:
# NOTE(review): despite the *_dataset names these are DataLoaders, not Datasets.
train_dataset = DataLoader(
    ACDCLoader(x_train, y_train, transform=False),
    batch_size=16,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=2,
    prefetch_factor=16,
)
# Validation and test loaders share the same (unshuffled) settings.
_eval_loader_kwargs = dict(batch_size=64, num_workers=2, prefetch_factor=64)
val_dataset = DataLoader(ACDCLoader(x_val, y_val, typeData="test"), **_eval_loader_kwargs)
test_dataset = DataLoader(ACDCLoader(x_test, y_test, typeData="test"), **_eval_loader_kwargs)

In [None]:
class Segmentor(pl.LightningModule):
    """LightningModule wrapping a 2-D segmentation network.

    Each step computes ``CrossEntropy`` loss and per-class Dice scores
    (``dice_rv`` / ``dice_myo`` / ``dice_lv``). These helpers, ``AttU_Net`` and
    ``Nadam`` are not defined in this notebook; presumably they come from the
    ``%run`` helper notebooks — confirm.

    Args:
        model: network to wrap. When omitted, a new ``AttU_Net(drop_prob=0)`` is
            built for this instance.
    """

    def __init__(self, model = None):
        super().__init__()
        # A default written as ``model=AttU_Net(...)`` is evaluated once at class
        # definition and shared by every instance, including every
        # ``load_from_checkpoint`` call. Build it per instance instead.
        self.model = AttU_Net(drop_prob=0) if model is None else model

    def forward(self, x):
        """Return the raw output of the wrapped network."""
        return self.model(x)

    def get_metrics(self):
        """Return ``super().get_metrics()`` without the ``v_num`` entry.

        NOTE(review): Lightning's progress-bar callbacks define ``get_metrics``;
        it is unclear that ``LightningModule`` does, so this override may never
        be called — confirm against the installed Lightning version.
        """
        # don't show the version number
        items = super().get_metrics()
        items.pop("v_num", None)
        return items

    def _step(self, batch):
        """Shared forward pass: return ``(loss, diceRV, diceMYO, diceLV)`` for one batch."""
        image, y_true = batch
        y_pred = self.model(image)
        loss = CrossEntropy(device=self.device)(y_true, y_pred)
        diceRV, diceMYO, diceLV = dice_rv(y_true, y_pred), dice_myo(y_true, y_pred), dice_lv(y_true, y_pred)
        return loss, diceRV, diceMYO, diceLV

    def training_step(self, batch, batch_idx):
        """Log loss and Dice per step and per epoch; return the loss to optimise."""
        loss, diceRV, diceMYO, diceLV = self._step(batch)
        metrics = {"loss": loss, "train_diceRV": diceRV, "train_diceMYO": diceMYO, "train_diceLV": diceLV}
        # on_step + on_epoch makes Lightning log "<name>_step" and "<name>_epoch"
        # variants; HistoryLogger reads the "_epoch" ones.
        self.log_dict(metrics, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and Dice (val_diceRV drives checkpointing and LR decay)."""
        loss, diceRV, diceMYO, diceLV = self._step(batch)
        metrics = {"val_loss": loss, "val_diceRV": diceRV, "val_diceMYO": diceMYO, "val_diceLV": diceLV}
        self.log_dict(metrics, prog_bar=True)
        return metrics

    def test_step(self, batch, batch_idx):
        """Log only the three test Dice scores (the loss is computed but not logged)."""
        loss, diceRV, diceMYO, diceLV = self._step(batch)
        metrics = {"test_diceRV": diceRV, "test_diceMYO": diceMYO, "test_diceLV": diceLV}
        self.log_dict(metrics, prog_bar=True)
        return metrics

    def configure_optimizers(self):
        """Nadam (lr=1e-3) plus ReduceLROnPlateau halving the LR when ``val_diceRV``
        has not improved for more than 15 epochs."""
        optimizer = Nadam(self.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max",
                                                         factor=0.5, patience=15, verbose=True)
        lr_scheduler_config = {"scheduler": scheduler, "monitor": "val_diceRV"}
        # Documented "two lists" return form: optimizers, then scheduler configs.
        return [optimizer], [lr_scheduler_config]


In [None]:
segmentor = Segmentor(AttU_Net(drop_prob=0))
# Keep the 5 best weight-only checkpoints ranked by validation RV Dice.
check_point = pl.callbacks.model_checkpoint.ModelCheckpoint("./weightAttUnet/", filename="ckpt{val_diceRV:0.4f}",
                                                            monitor="val_diceRV", mode="max", save_top_k=5,
                                                            verbose=True, save_weights_only=True,
                                                            auto_insert_metric_name=False,)
# Named so it does not shadow fastprogress's ``progress_bar`` imported earlier.
tqdm_bar = pl.callbacks.TQDMProgressBar()
logger = HistoryLogger()
# NOTE(review): not passed in "callbacks", so SWA is inactive; with max_epochs=5
# it would also never reach swa_epoch_start=25.
swa = pl.callbacks.StochasticWeightAveraging(swa_epoch_start=25)
PARAMS = {"gpus": 1, "benchmark": True,
          # A progress-bar callback is supplied below, so the bar must be enabled;
          # some Lightning versions reject enable_progress_bar=False together with a
          # progress-bar callback.
          "enable_progress_bar": True,
          "logger": False,
          "callbacks": [check_point, tqdm_bar, logger],
          "log_every_n_steps": 1, "num_sanity_val_steps": 0, "max_epochs": 5,
          "precision": 16,
          }

trainer = pl.Trainer(**PARAMS)

# To resume from saved weights instead:
# segmentor = Segmentor.load_from_checkpoint(checkpoint_path="./weightAttUnet/ckpt0.8294.ckpt")
# segmentor = Segmentor.load_from_checkpoint(checkpoint_path="./weightAttUnet/current.ckpt")

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# Train with the configured callbacks: top-k checkpoints go to ./weightAttUnet/ and
# per-epoch metrics are appended to the HistoryLogger CSV.
trainer.fit(segmentor, train_dataset, val_dataset)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type     | Params
-----------------------------------
0 | model | AttU_Net | 31.4 M
-----------------------------------
31.4 M    Trainable params
0         Non-trainable params
31.4 M    Total params
62.778    Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Training: 0it [00:00, ?it/s]

  self.padding, self.dilation, self.groups)


Validating: 0it [00:00, ?it/s]

Epoch 0, global step 327: val_diceRV reached 0.42700 (best 0.42700), saving model to "/content/drive/My Drive/new_paper/ACDC/weightAttUnet/ckpt0.4270.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 1, global step 655: val_diceRV reached 0.62552 (best 0.62552), saving model to "/content/drive/My Drive/new_paper/ACDC/weightAttUnet/ckpt0.6255.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 2, global step 983: val_diceRV reached 0.63477 (best 0.63477), saving model to "/content/drive/My Drive/new_paper/ACDC/weightAttUnet/ckpt0.6348.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 3, global step 1311: val_diceRV reached 0.70002 (best 0.70002), saving model to "/content/drive/My Drive/new_paper/ACDC/weightAttUnet/ckpt0.7000.ckpt" as top 5


Validating: 0it [00:00, ?it/s]

Epoch 4, global step 1639: val_diceRV reached 0.70808 (best 0.70808), saving model to "/content/drive/My Drive/new_paper/ACDC/weightAttUnet/ckpt0.7081.ckpt" as top 5


In [None]:
# Save the weights as they are at the end of training.
# NOTE(review): this file lands in ./weightAttUnet/, which the volumetric
# evaluation cell globs, so it is scored there alongside the top-k checkpoints.
trainer.save_checkpoint("./weightAttUnet/current.ckpt", weights_only=True)


In [None]:
# Raise the DropBlock probability on the current model via the shared helper,
# then print the value of every DropBlock2D layer to confirm it took effect.
setDropProb(segmentor, prob=0.1)

for layer in segmentor.modules():
    if isinstance(layer, DropBlock2D):
        print(layer.drop_prob)

0.1
0.1


In [None]:
# Evaluate the current model on the test loader; the mean of the logged test
# metrics (the three per-class Dice scores) is displayed as the cell result.
test_metrics = trainer.test(segmentor, test_dataset)
dice_scores = test_metrics[0]
dice_ave = np.mean([score for score in dice_scores.values()])
dice_ave

In [None]:
# Score every checkpoint in ./weights/ on the test loader (mean slice-wise Dice)
# and keep the best one.
# NOTE(review): training above writes to ./weightAttUnet/, not ./weights/ —
# confirm which directory is intended here.
weight_path = sorted(glob.glob("./weights/*"), reverse=True)
best_weight, max_score = "", 0
for weight in tqdm(weight_path):
    segmentor = Segmentor.load_from_checkpoint(checkpoint_path=weight)
    dice_scores = trainer.test(segmentor, test_dataset)[0]
    dice_ave = np.mean(list(dice_scores.values()))
    clear_output()
    print(dice_ave)
    if dice_ave > max_score:
        best_weight, max_score = weight, dice_ave
print(f"best weight is: {best_weight} with {max_score}")


In [None]:
def predict(images, model, batch_size = 64):
    """Run ``model`` over ``images`` in mini-batches with a tqdm progress bar.

    Args:
        images: array-like indexed as (N, C, H, W) (sizes 0, 2 and 3 are read);
            converted to a float32 tensor.
        model: callable applied to each batch after it is moved to the global
            ``device``; its output must fit a (batch, NUM_CLASS, H, W) slot.
        batch_size: samples per forward pass; must be positive.

    Returns:
        numpy.ndarray of shape (N, NUM_CLASS, H, W) holding the raw model outputs.

    Raises:
        ValueError: if ``batch_size`` is not positive (it would otherwise loop forever).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    images = torch.as_tensor(images, dtype=torch.float32)
    num_images = images.size(0)
    y_preds = torch.zeros((num_images, NUM_CLASS, images.size(2), images.size(3)), device=device)
    num_batches = (num_images + batch_size - 1) // batch_size
    # The ``with`` form closes the progress bar even if a forward pass raises.
    with torch.inference_mode(), tqdm(total=num_batches) as pbar:
        for start in range(0, num_images, batch_size):
            batch = images[start : start + batch_size].to(device)
            y_preds[start : start + batch_size] = model(batch)
            pbar.update(1)
    return y_preds.cpu().numpy()

In [None]:
# x_inf = np.zeros((x_test.shape[0], 1, x_test.shape[1], x_test.shape[2]))
# for i in range(x_test.shape[0]):
#     x_inf[i, 0] = min_max_preprocess(x_test[i])
# segmentor = segmentor.to(device)
# segmentor.eval()
# y_pred = predict(x_inf, segmentor)
# torch.cuda.empty_cache()
# mask_predict = np.argmax(y_pred, axis=1)


In [None]:
def dice_volume_class(y_true, y_pred, label, smooth = 1e-5):
    """Dice coefficient between the ``label`` regions of two label tensors.

    Args:
        y_true: integer label tensor (any shape).
        y_pred: integer label tensor with the same (or a broadcastable) shape.
        label: class value whose binary masks are compared.
        smooth: added to numerator and denominator; two empty masks score 1.0.

    Returns:
        0-dim float tensor ``(2*|A∩B| + smooth) / (|A| + |B| + smooth)``.
    """
    pred_mask = (y_pred == label).long()
    true_mask = (y_true == label).long()
    intersection = torch.sum(true_mask * pred_mask)
    cardinality = torch.sum(true_mask + pred_mask)
    return (2. * intersection + smooth) / (cardinality + smooth)


def dice_volume_rv(y_true, y_pred, smooth = 1e-5):
    """Dice for label 1 (RV); see ``dice_volume_class``."""
    return dice_volume_class(y_true, y_pred, 1, smooth)


def dice_volume_myo(y_true, y_pred, smooth = 1e-5):
    """Dice for label 2 (MYO); see ``dice_volume_class``."""
    return dice_volume_class(y_true, y_pred, 2, smooth)


def dice_volume_lv(y_true, y_pred, smooth = 1e-5):
    """Dice for label 3 (LV); see ``dice_volume_class``."""
    return dice_volume_class(y_true, y_pred, 3, smooth)

def predict_volume(images, model, batch_size = 64):
    """Run ``model`` over ``images`` in mini-batches without a progress bar.

    Args:
        images: array-like indexed as (N, C, H, W) (sizes 0, 2 and 3 are read);
            converted to a float32 tensor.
        model: callable applied to each batch after it is moved to the global
            ``device``; its output must fit a (batch, NUM_CLASS, H, W) slot.
        batch_size: samples per forward pass; must be positive.

    Returns:
        numpy.ndarray of shape (N, NUM_CLASS, H, W) holding the raw model outputs.

    Raises:
        ValueError: if ``batch_size`` is not positive (it would otherwise loop forever).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    images = torch.as_tensor(images, dtype=torch.float32)
    num_images = images.size(0)
    y_preds = torch.zeros((num_images, NUM_CLASS, images.size(2), images.size(3)), device=device)
    with torch.inference_mode():
        for start in range(0, num_images, batch_size):
            batch = images[start : start + batch_size].to(device)
            y_preds[start : start + batch_size] = model(batch)
    return y_preds.cpu().numpy()

In [None]:
# Re-create the seeded patient split so the test volumes can be paired with the
# slice predictions made on x_test.
all_files = sorted(glob.glob("./dataACDCA/training/*"))
np.random.seed(42)
np.random.shuffle(all_files)
train_count = 70
train_files = all_files[:train_count]
test_files = all_files[train_count : train_count + 10] + all_files[-10:]
valid_files = all_files[train_count + 10 : train_count + 20]

# Network input: one channel per slice, each slice passed through min_max_preprocess.
x_inf = np.zeros((x_test.shape[0], 1, x_test.shape[1], x_test.shape[2]))
for i in range(x_test.shape[0]):
    x_inf[i, 0] = min_max_preprocess(x_test[i])


def _volume_dice_scores(mask_predict, patient_dirs):
    """Return the mean volumetric Dice ``(RV, MYO, LV)`` over every frame in ``patient_dirs``.

    ``mask_predict`` holds one predicted 2-D label map per slice. Slices are
    consumed sequentially, one ground-truth volume at a time, and stacked along
    the volume's last axis before scoring with the dice_volume_* helpers.
    """
    list_rv, list_myo, list_lv = [], [], []
    count = 0
    for patient_dir in patient_dirs:
        # NOTE(review): glob order is filesystem-dependent and must match the order
        # used to build ACDC_test160.npz — confirm against preprocessing_2D.ipynb.
        frame_paths = [p for p in glob.glob(patient_dir + "/*") if "frame" in p and "gt" not in p]
        for image_name in frame_paths:
            mask_name = image_name[: image_name.find("nii") - 1] + "_gt.nii.gz"
            label = center_crop(nib.load(mask_name).get_fdata().astype(np.uint8))
            depth = label.shape[-1]
            # Place this volume's slice predictions along the last axis, cast to label's dtype.
            label_pred = np.zeros_like(label)
            label_pred[...] = np.moveaxis(mask_predict[count : count + depth], 0, -1)
            count += depth
            label_t = torch.from_numpy(label)
            pred_t = torch.from_numpy(label_pred)
            list_rv.append(dice_volume_rv(label_t, pred_t).item())
            list_myo.append(dice_volume_myo(label_t, pred_t).item())
            list_lv.append(dice_volume_lv(label_t, pred_t).item())
    return np.mean(list_rv), np.mean(list_myo), np.mean(list_lv)


# Score every checkpoint in ./weightAttUnet/ with volumetric Dice; keep the best.
weight_path = sorted(glob.glob("./weightAttUnet/*"), reverse=True)
max_score = 0
best_weight = ""
for weight in tqdm(weight_path):
    segmentor = Segmentor.load_from_checkpoint(checkpoint_path=weight)
    segmentor = segmentor.to(device)
    segmentor.eval()
    y_pred = predict_volume(x_inf, segmentor)
    torch.cuda.empty_cache()
    mask_predict = np.argmax(y_pred, axis=1)
    diceRV, diceMYO, diceLV = _volume_dice_scores(mask_predict, test_files)
    dice_ave = np.mean([diceRV, diceMYO, diceLV])
    clear_output()
    print(dice_ave)
    if max_score < dice_ave:
        max_score = dice_ave
        best_weight = weight
print(f"best weight is: {best_weight} with {max_score}")

100%|██████████| 2/2 [00:06<00:00,  3.40s/it]

0.8980089778701464
best weight is: ./weightAttUnet/ourLoss.ckpt with 0.9165331775943439





In [None]:
# Per-class volumetric Dice of the last evaluated checkpoint, then their mean.
per_class_dice = [diceRV, diceMYO, diceLV]
print(*per_class_dice, np.mean(per_class_dice))

0.8940340965986252 0.8656079322099686 0.9343849048018456 0.8980089778701464


In [None]:
# all_files = sorted(glob.glob("./dataACDCA/training/*"))
# np.random.seed(42)
# np.random.shuffle(all_files)
# train_count = 70
# train_files = all_files[:train_count]
# test_files = all_files[train_count : train_count + 10]+all_files[-10:]
# valid_files = all_files[train_count + 10: train_count + 20]
# x_inf = np.zeros((x_test.shape[0], 1, x_test.shape[1], x_test.shape[2]))
# for i in range(x_test.shape[0]):
#     x_inf[i, 0] = min_max_preprocess(x_test[i])
# weight_path = sorted(glob.glob("./weights/*"), reverse=True)
# # weight_path = ["./weights/ckpt0.8533.ckpt", "./weights/ckpt0.8288.ckpt"]
# max_score = 0
# best_weight = ""
# for weight in tqdm(weight_path):
#     segmentor = Segmentor.load_from_checkpoint(checkpoint_path=weight)
#     segmentor = segmentor.to(device)
#     segmentor.eval()
#     y_pred = predict_volume(x_inf, segmentor)
#     torch.cuda.empty_cache()
#     mask_predict = np.argmax(y_pred, axis=1)
#     count = 0
#     dice_rv_all, dice_myo_all, dice_lv_all = [], [], []
#     for files in test_files:
#         list_rv, list_myo, list_lv = [], [], []
#         list_image = [x for x in glob.glob(files+"/*") if x.find('frame') != -1 and x.find('gt') == -1]
#         for image_name in list_image:
#             num = image_name.find("nii")
#             mask_name = image_name[:num-1] +"_gt.nii.gz"
#             image = nib.load(image_name).get_fdata().astype(np.uint16)
#             label = nib.load(mask_name).get_fdata().astype(np.uint8)
#             image = center_crop(image)
#             label = center_crop(label)
#             for z in range(label.shape[-1]):
#                 sub_label = torch.from_numpy(label[...,z])
#                 label_pred = torch.from_numpy(mask_predict[count])
#                 list_rv.append(dice_volume_rv(sub_label, label_pred).item())
#                 list_myo.append(dice_volume_myo(sub_label, label_pred).item())
#                 list_lv.append(dice_volume_lv(sub_label, label_pred).item())
#                 count += 1
#             dice_rv_all.append(np.mean(list_rv))
#             dice_myo_all.append(np.mean(list_myo))
#             dice_lv_all.append(np.mean(list_lv))
#     dice_RV = np.mean(dice_rv_all)
#     dice_MYO = np.mean(dice_myo_all)
#     dice_LV = np.mean(dice_lv_all)
#     dice_ave = np.mean([dice_RV, dice_MYO, dice_LV])
#     print(dice_ave)




In [None]:
# np.savez_compressed("ACDC_predict_mask", predicted_mask=mask_predict)

In [None]:
# Visual check for every test slice: input image, ground-truth mask and predicted
# mask side by side. Each figure is shown and then closed, so hundreds of open
# figures do not accumulate in memory.
for i in range(x_test.shape[0]):
    fig, (ax_img, ax_gt, ax_pred) = plt.subplots(1, 3)
    ax_img.imshow(x_test[i])
    ax_img.set_title(f"Image {i}")
    ax_gt.imshow(y_test[i])
    ax_gt.set_title("Ground truth")
    ax_pred.imshow(mask_predict[i])
    ax_pred.set_title("Prediction")
    plt.show()
    plt.close(fig)
