In [None]:
%sh
# Pin to the versions recorded in this cell's install log so re-runs are reproducible.
/databricks/python/bin/pip install loguru==0.6.0 awscli==1.27.4 nibabel==4.0.2

Collecting loguru
  Downloading loguru-0.6.0-py3-none-any.whl (58 kB)
Collecting awscli
  Downloading awscli-1.27.4-py3-none-any.whl (3.9 MB)
Collecting nibabel
  Downloading nibabel-4.0.2-py3-none-any.whl (3.3 MB)
Collecting PyYAML<5.5,>=3.10
  Downloading PyYAML-5.4.1-cp39-cp39-manylinux1_x86_64.whl (630 kB)
Collecting s3transfer<0.7.0,>=0.6.0
  Downloading s3transfer-0.6.0-py3-none-any.whl (79 kB)
Collecting docutils<0.17,>=0.10
  Downloading docutils-0.16-py2.py3-none-any.whl (548 kB)
Collecting rsa<4.8,>=3.1.2
  Downloading rsa-4.7.2-py3-none-any.whl (34 kB)
Collecting colorama<0.4.5,>=0.2.5
  Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)
Collecting botocore==1.29.4
  Downloading botocore-1.29.4-py3-none-any.whl (9.8 MB)
Installing collected packages: botocore, s3transfer, rsa, PyYAML, docutils, colorama, nibabel, loguru, awscli
  Attempting uninstall: botocore
    Found existing installation: botocore 1.24.18
    Uninstalling botocore-1.24.18:
      Successfully uninsta

In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch import cat
from loguru import logger
# import deepspeed
# from apex import amp
from torch.utils.data import Dataset
from scipy.ndimage import zoom
from tensorflow.keras.utils import to_categorical
import os, time, sys
from torch.cuda.amp import autocast, GradScaler

# (1) base version model

In [None]:
# 3D-UNet model.
# x: 128x128 resolution for 32 frames.
# https://github.com/huangzhii/FCN-3D-pytorch/blob/master/main3d.py
import torch
import torch.nn as nn
import os
import numpy as np
from collections import OrderedDict

def passthrough(x, **kwargs):
    """Identity stand-in for an optional layer: hand ``x`` back untouched.

    Keyword arguments are accepted and ignored so the function can replace a
    layer call site that passes options. Used where dropout is disabled.
    """
    del kwargs  # intentionally unused
    return x


def ELUCons(elu, nchan):
    """Build the activation used by every V-Net block.

    Args:
        elu: truthy -> ``nn.ELU(inplace=True)``; falsy -> ``nn.PReLU`` with one
            learnable slope per channel.
        nchan: number of channels; only consulted for the PReLU case.
    """
    return nn.ELU(inplace=True) if elu else nn.PReLU(nchan)


class LUConv(nn.Module):
    """Conv3d(5x5x5, padding 2) -> BatchNorm3d -> activation.

    Channel count and spatial size are preserved (kernel 5, padding 2, stride 1).

    Args:
        nchan: number of input and output channels.
        elu: ELU if truthy, otherwise per-channel PReLU (see ``ELUCons``).
    """

    def __init__(self, nchan, elu):
        super().__init__()
        # Registration order (relu1, conv1, bn1) matches the original module layout.
        self.relu1 = ELUCons(elu, nchan)
        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
        self.bn1 = torch.nn.BatchNorm3d(nchan)

    def forward(self, x):
        convolved = self.conv1(x)
        normalised = self.bn1(convolved)
        return self.relu1(normalised)


def _make_nConv(nchan, depth, elu):
    """Return an ``nn.Sequential`` of ``depth`` channel-preserving ``LUConv`` layers."""
    return nn.Sequential(*(LUConv(nchan, elu) for _ in range(depth)))


class InputTransition(nn.Module):
    """First V-Net stage: 5x5x5 conv to 16 features with a channel-tiled residual.

    The input is repeated along the channel axis up to 16 channels and added to
    the BatchNorm'd convolution output before the activation.

    Args:
        in_channels: number of input channels. Must be a positive divisor of 16
            so that the tiled residual has exactly 16 channels.
        elu: ELU if truthy, otherwise per-channel PReLU (see ``ELUCons``).

    Raises:
        ValueError: if ``in_channels`` is not a positive divisor of 16.
    """

    def __init__(self, in_channels, elu):
        super(InputTransition, self).__init__()
        self.num_features = 16
        self.in_channels = in_channels
        # Fail fast: with a non-divisor the tiled residual has the wrong channel
        # count and forward() would only fail later with an opaque size mismatch.
        if in_channels <= 0 or self.num_features % in_channels != 0:
            raise ValueError(
                f"in_channels must be a positive divisor of {self.num_features}, got {in_channels}"
            )

        self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2)

        self.bn1 = torch.nn.BatchNorm3d(self.num_features)

        self.relu1 = ELUCons(elu, self.num_features)

    def forward(self, x):
        out = self.bn1(self.conv1(x))
        # Tile the input along channels so the residual matches num_features channels.
        x16 = x.repeat(1, self.num_features // self.in_channels, 1, 1, 1)
        return self.relu1(torch.add(out, x16))


class DownTransition(nn.Module):
    """V-Net encoder stage.

    A 2x2x2 stride-2 convolution halves (floor) every spatial dimension and
    doubles the channels. ``nConvs`` LUConv layers then process the result, and
    their output is added back to the downsampled tensor (residual) before a
    final activation.

    Args:
        inChans: input channels; the output has ``2 * inChans``.
        nConvs: number of LUConv layers in the residual branch.
        elu: ELU if truthy, otherwise PReLU.
        dropout: if True, apply ``nn.Dropout3d`` to the branch input.
    """

    def __init__(self, inChans, nConvs, elu, dropout=False):
        super().__init__()
        widened = inChans * 2
        self.down_conv = nn.Conv3d(inChans, widened, kernel_size=2, stride=2)
        self.bn1 = torch.nn.BatchNorm3d(widened)
        self.relu1 = ELUCons(elu, widened)
        self.relu2 = ELUCons(elu, widened)
        self.do1 = nn.Dropout3d() if dropout else passthrough
        self.ops = _make_nConv(widened, nConvs, elu)

    def forward(self, x):
        shortcut = self.relu1(self.bn1(self.down_conv(x)))
        branch = self.ops(self.do1(shortcut))
        return self.relu2(torch.add(branch, shortcut))


class UpTransition(nn.Module):
    """V-Net decoder stage with a skip connection.

    Steps:
      1. ``x`` (optionally dropped out) is upsampled by a 2x2x2 stride-2
         transposed conv to ``outChans // 2`` channels, then BatchNorm'd and
         activated.
      2. The result is concatenated on dim 1 with the skip tensor. Dropout3d is
         always applied to the skip tensor.
      3. ``nConvs`` LUConv layers process the concatenation, and their output is
         added back to it before the final activation.

    Args:
        inChans: channels of ``x``.
        outChans: channels after concatenation (and of the output).
        nConvs: number of LUConv layers.
        elu: ELU if truthy, otherwise PReLU.
        dropout: if True, also apply Dropout3d to ``x`` before upsampling.
    """

    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
        super().__init__()
        half = outChans // 2
        self.up_conv = nn.ConvTranspose3d(inChans, half, kernel_size=2, stride=2)
        self.bn1 = torch.nn.BatchNorm3d(half)
        self.do2 = nn.Dropout3d()
        self.relu1 = ELUCons(elu, half)
        self.relu2 = ELUCons(elu, outChans)
        self.do1 = nn.Dropout3d() if dropout else passthrough
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x, skipx):
        x_dropped = self.do1(x)
        skip_dropped = self.do2(skipx)
        upsampled = self.relu1(self.bn1(self.up_conv(x_dropped)))
        merged = torch.cat((upsampled, skip_dropped), 1)
        return self.relu2(torch.add(self.ops(merged), merged))


class OutputTransition(nn.Module):
    """Final V-Net head producing ``classes`` channels of raw scores.

    The head is a 5x5x5 conv (in_channels -> classes), then BatchNorm and the
    activation, followed by a 1x1x1 conv. No softmax is applied here.

    Args:
        in_channels: channels of the incoming feature map.
        classes: number of output channels.
        elu: ELU if truthy, otherwise PReLU.
    """

    def __init__(self, in_channels, classes, elu):
        super().__init__()
        self.classes = classes
        self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2)
        self.bn1 = torch.nn.BatchNorm3d(classes)
        self.conv2 = nn.Conv3d(classes, classes, kernel_size=1)
        self.relu1 = ELUCons(elu, classes)

    def forward(self, x):
        # Project features down to one channel per class, then mix per voxel.
        projected = self.conv1(x)
        activated = self.relu1(self.bn1(projected))
        logits = self.conv2(activated)
        return logits

class UpTransitionNoConv(nn.Module):
    """Upsample-and-concatenate half of ``UpTransition``, without the LUConv stack.

    ``x`` (optionally dropped out) goes through a 2x2x2 stride-2 transposed conv
    to ``outChans // 2`` channels, then BatchNorm and the activation. The result
    is concatenated on dim 1 with the skip tensor, to which Dropout3d is always
    applied. The LUConv stack and the residual activation are left to the
    caller.

    Args:
        inChans: channels of ``x`` (the model-parallel VNet passes 64).
        outChans: ``outChans // 2`` is the upsampled channel count.
        nConvs: unused; kept for signature parity with ``UpTransition``.
        elu: ELU if truthy, otherwise PReLU.
        dropout: if True, also apply Dropout3d to ``x`` before upsampling.
    """

    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
        super().__init__()
        half = outChans // 2
        self.up_conv = nn.ConvTranspose3d(inChans, half, kernel_size=2, stride=2)
        self.bn1 = torch.nn.BatchNorm3d(half)
        self.do2 = nn.Dropout3d()
        self.relu1 = ELUCons(elu, half)
        self.do1 = nn.Dropout3d() if dropout else passthrough

    def forward(self, x, skipx):
        x_in = self.do1(x)
        skip = self.do2(skipx)
        upsampled = self.relu1(self.bn1(self.up_conv(x_in)))
        return torch.cat((upsampled, skip), 1)


# --- 1.基础的vnet3d

In [None]:
class VNet(nn.Module):
    """
    Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797

    Single-device V-Net. An input transition (16 channels) is followed by four
    down transitions (up to 256 channels) and four up transitions that consume
    the encoder outputs as skip connections. An ``OutputTransition`` then
    returns ``classes`` channels of raw scores (no softmax).

    Spatial dimensions should be divisible by 16. Each of the four stride-2
    downsamplings floors odd sizes, and upsampled tensors must match their
    skip tensors for concatenation.

    Args:
        elu: ELU activations if True, otherwise PReLU.
        in_channels: input channels (must divide 16, see ``InputTransition``).
        classes: number of output channels.
    """

    def __init__(self, elu=True, in_channels=1, classes=1):
        super().__init__()
        self.classes = classes
        self.in_channels = in_channels

        self.in_tr = InputTransition(in_channels, elu=elu)
        self.down_tr32 = DownTransition(16, 1, elu)
        self.down_tr64 = DownTransition(32, 2, elu)
        self.down_tr128 = DownTransition(64, 3, elu, dropout=False)
        self.down_tr256 = DownTransition(128, 2, elu, dropout=False)
        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)
        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)
        self.up_tr64 = UpTransition(128, 64, 1, elu)
        self.up_tr32 = UpTransition(64, 32, 1, elu)
        self.out_tr = OutputTransition(32, classes, elu)

    def forward(self, x):
        skip16 = self.in_tr(x)
        skip32 = self.down_tr32(skip16)
        skip64 = self.down_tr64(skip32)
        skip128 = self.down_tr128(skip64)
        hidden = self.down_tr256(skip128)
        # Decoder: each up transition pairs with the matching encoder output.
        decoder_steps = (
            (self.up_tr256, skip128),
            (self.up_tr128, skip64),
            (self.up_tr64, skip32),
            (self.up_tr32, skip16),
        )
        for up_stage, skip in decoder_steps:
            hidden = up_stage(hidden, skip)
        return self.out_tr(hidden)


# --- 2.模型并行VNet3d版本

In [None]:
# Target devices for the manual model-parallel split below. Index 4 ('cpu') is
# not used by this class. Requires at least four visible CUDA devices.
devices = ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3', 'cpu']

class VNet_Parallelism(nn.Module):
    """
    Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797

    Model-parallel V-Net (variant 1). Submodules are pinned to different GPUs,
    and forward() copies activations between them with ``.to(...)``.

    Placement:
      - devices[0]: in_tr, down_tr32, down_tr64
      - devices[1]: down_tr128, down_tr256, up_tr256, out_tr
      - devices[2]: up_tr128, up_tr64, up_tr32_ops, up_tr32_relu2
      - devices[3]: up_tr32 (upsample + skip concatenation only)
    The returned logits live on devices[1].

    NOTE(review): this notebook defines VNet_Parallelism in several cells;
    whichever cell ran last is the class that gets instantiated.
    """

    def __init__(self, elu=True, in_channels=1, classes=1):
        super(VNet_Parallelism, self).__init__()
        self.classes = classes
        self.in_channels = in_channels
        logger.info("begin initialize model struction")
        self.in_tr = InputTransition(in_channels, elu=elu).to(devices[0])
        self.down_tr32 = DownTransition(16, 1, elu).to(devices[0])
        self.down_tr64 = DownTransition(32, 2, elu).to(devices[0])
        self.down_tr128 = DownTransition(64, 3, elu, dropout=False).to(devices[1])
        self.down_tr256 = DownTransition(128, 2, elu, dropout=False).to(devices[1])
        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False).to(devices[1])
        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False).to(devices[2])
        self.up_tr64 = UpTransition(128, 64, 1, elu).to(devices[2])
        # Version 1: the whole last up-transition on devices[3].
        # self.up_tr32 = UpTransition(64, 32, 1, elu).to(devices[3])

        # Version 2: split the last up-transition. Upsample + concat stay on
        # devices[3]; its LUConv stack and residual activation run on devices[2].
        self.up_tr32 = UpTransitionNoConv(64, 32, 1, elu).to(devices[3])
        self.up_tr32_ops = _make_nConv(32, 1, elu).to(devices[2])
        self.up_tr32_relu2 = ELUCons(elu, 32).to(devices[2])


        self.out_tr = OutputTransition(32, classes, elu).to(devices[1])



    def forward(self, x):
        """Run the split forward pass; ``x`` is moved to devices[0], logits come back on devices[1]."""
        # Encoder, first half on devices[0].
        out16 = self.in_tr(x.to(devices[0]))
        logger.debug(f"out16: {out16.shape}")
        out32 = self.down_tr32(out16)
        logger.debug(f"out32: {out32.shape}")
        out64 = self.down_tr64(out32)
        # print(f"out64: {out64.shape}, self.down_tr128: {next(self.down_tr128.parameters()).device}")
        # Encoder second half and first decoder stage on devices[1].
        out128 = self.down_tr128(out64.to(devices[1]))
        logger.debug(f"out128: {out128.shape}")
        out256 = self.down_tr256(out128)
        logger.debug(f"out256: {out256.shape}, out128: {out128.shape}")
        out = self.up_tr256(out256, out128)
        logger.debug(f"out: {out.shape}， out64: {out64.shape}")
        # Decoder on devices[2]; skip tensors produced on devices[0] are copied over.
        out = self.up_tr128(out.to(devices[2]), out64.to(devices[2]))
        logger.debug(f"out: {out.shape}, out32: {out32.shape}")
        out = self.up_tr64(out.to(devices[2]), out32.to(devices[2]))
        logger.debug(f"out: {out.shape}, out16: {out16.shape}")
        # Version 1:
        # out = self.up_tr32(out.to(devices[3]), out16.to(devices[3]))

        # Version 2: upsample + concat on devices[3], then the conv stack and
        # residual add + activation back on devices[2].
        out_uptr32 = self.up_tr32(out.to(devices[3]), out16.to(devices[3]))
        out_uptr32_ops_val = self.up_tr32_ops(out_uptr32.to(devices[2]))
        out = self.up_tr32_relu2(torch.add(out_uptr32_ops_val, out_uptr32.to(devices[2])))

        logger.debug(f"out: {out.shape}")
        # Output head on devices[1].
        out = self.out_tr(out.to(devices[1]))
        logger.debug(f"out: {out.shape}")
        return out


# --- 3. version 2

In [None]:
# Target devices for the manual model-parallel split below. Index 4 ('cpu') is
# not used by this class. Requires at least four visible CUDA devices.
devices = ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3', 'cpu']

class VNet_Parallelism(nn.Module):
    """
    Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797

    Model-parallel V-Net (variant 2). This has the same placement as variant 1,
    except that the final residual activation (up_tr32_relu2) and its add run
    on devices[0] instead of devices[2].

    Placement:
      - devices[0]: in_tr, down_tr32, down_tr64, up_tr32_relu2
      - devices[1]: down_tr128, down_tr256, up_tr256, out_tr
      - devices[2]: up_tr128, up_tr64, up_tr32_ops
      - devices[3]: up_tr32 (upsample + skip concatenation only)
    The returned logits live on devices[1].

    NOTE(review): this notebook defines VNet_Parallelism in several cells;
    whichever cell ran last is the class that gets instantiated.
    """

    def __init__(self, elu=True, in_channels=1, classes=1):
        super(VNet_Parallelism, self).__init__()
        self.classes = classes
        self.in_channels = in_channels
        logger.info("begin initialize model struction")
        self.in_tr = InputTransition(in_channels, elu=elu).to(devices[0])
        self.down_tr32 = DownTransition(16, 1, elu).to(devices[0])
        self.down_tr64 = DownTransition(32, 2, elu).to(devices[0])
        self.down_tr128 = DownTransition(64, 3, elu, dropout=False).to(devices[1])
        self.down_tr256 = DownTransition(128, 2, elu, dropout=False).to(devices[1])
        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False).to(devices[1])
        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False).to(devices[2])
        self.up_tr64 = UpTransition(128, 64, 1, elu).to(devices[2])
        # Version 1: the whole last up-transition on devices[3].
        # self.up_tr32 = UpTransition(64, 32, 1, elu).to(devices[3])

        # Version 2: split the last up-transition. Upsample + concat on devices[3],
        # LUConv stack on devices[2], residual activation on devices[0].
        self.up_tr32 = UpTransitionNoConv(64, 32, 1, elu).to(devices[3])
        self.up_tr32_ops = _make_nConv(32, 1, elu).to(devices[2])
        self.up_tr32_relu2 = ELUCons(elu, 32).to(devices[0])


        self.out_tr = OutputTransition(32, classes, elu).to(devices[1])



    def forward(self, x):
        """Run the split forward pass; ``x`` is moved to devices[0], logits come back on devices[1]."""
        # Encoder, first half on devices[0].
        out16 = self.in_tr(x.to(devices[0]))
        logger.debug(f"out16: {out16.shape}")
        out32 = self.down_tr32(out16)
        logger.debug(f"out32: {out32.shape}")
        out64 = self.down_tr64(out32)
        # print(f"out64: {out64.shape}, self.down_tr128: {next(self.down_tr128.parameters()).device}")
        # Encoder second half and first decoder stage on devices[1].
        out128 = self.down_tr128(out64.to(devices[1]))
        logger.debug(f"out128: {out128.shape}")
        out256 = self.down_tr256(out128)
        logger.debug(f"out256: {out256.shape}, out128: {out128.shape}")
        out = self.up_tr256(out256, out128)
        logger.debug(f"out: {out.shape}， out64: {out64.shape}")
        # Decoder on devices[2]; skip tensors produced on devices[0] are copied over.
        out = self.up_tr128(out.to(devices[2]), out64.to(devices[2]))
        logger.debug(f"out: {out.shape}, out32: {out32.shape}")
        out = self.up_tr64(out.to(devices[2]), out32.to(devices[2]))
        logger.debug(f"out: {out.shape}, out16: {out16.shape}")
        # Version 1:
        # out = self.up_tr32(out.to(devices[3]), out16.to(devices[3]))

        # Version 2: upsample + concat on devices[3], conv stack on devices[2],
        # then both operands of the residual add are moved to devices[0].
        out_uptr32 = self.up_tr32(out.to(devices[3]), out16.to(devices[3]))
        out_uptr32_ops_val = self.up_tr32_ops(out_uptr32.to(devices[2]))
        out = self.up_tr32_relu2(torch.add(out_uptr32_ops_val.to(devices[0]), out_uptr32.to(devices[0])))

        logger.debug(f"out: {out.shape}")
        # Output head on devices[1].
        out = self.out_tr(out.to(devices[1]))
        logger.debug(f"out: {out.shape}")
        return out


# --- 4. CPU loading (single-device version)

In [None]:
class VNet_Parallelism(nn.Module):
    """
    Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797

    Single-device ("CPU loading") variant of the model-parallel V-Net. It has
    the same submodule names as the GPU-split variants, including the split
    up_tr32 / up_tr32_ops / up_tr32_relu2, so their state_dict keys line up.
    Nothing is moved explicitly: parameters stay wherever the module is created
    (CPU by default) until the caller moves it.

    NOTE(review): this notebook defines VNet_Parallelism in several cells;
    whichever cell ran last is the class that gets instantiated.
    """

    def __init__(self, elu=True, in_channels=1, classes=1):
        super(VNet_Parallelism, self).__init__()
        self.classes = classes
        self.in_channels = in_channels
        logger.info("begin initialize model struction")
        self.in_tr = InputTransition(in_channels, elu=elu)
        self.down_tr32 = DownTransition(16, 1, elu)
        self.down_tr64 = DownTransition(32, 2, elu)
        self.down_tr128 = DownTransition(64, 3, elu, dropout=False)
        self.down_tr256 = DownTransition(128, 2, elu, dropout=False)
        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)
        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)
        self.up_tr64 = UpTransition(128, 64, 1, elu)
        # Version 1: single UpTransition for the last stage.
        # self.up_tr32 = UpTransition(64, 32, 1, elu).to(devices[3])

        # Version 2: last stage split into upsample+concat, LUConv stack and
        # residual activation (same layout as the GPU-split variants).
        self.up_tr32 = UpTransitionNoConv(64, 32, 1, elu)
        self.up_tr32_ops = _make_nConv(32, 1, elu)
        self.up_tr32_relu2 = ELUCons(elu, 32)


        self.out_tr = OutputTransition(32, classes, elu)



    def forward(self, x):
        """Forward pass on a single device; logits are returned on ``x``'s device."""
        out16 = self.in_tr(x)
        logger.debug(f"out16: {out16.shape}")
        out32 = self.down_tr32(out16)
        logger.debug(f"out32: {out32.shape}")
        out64 = self.down_tr64(out32)
        # print(f"out64: {out64.shape}, self.down_tr128: {next(self.down_tr128.parameters()).device}")
        out128 = self.down_tr128(out64)
        logger.debug(f"out128: {out128.shape}")
        out256 = self.down_tr256(out128)
        logger.debug(f"out256: {out256.shape}, out128: {out128.shape}")
        out = self.up_tr256(out256, out128)
        logger.debug(f"out: {out.shape}， out64: {out64.shape}")
        out = self.up_tr128(out, out64)
        logger.debug(f"out: {out.shape}, out32: {out32.shape}")
        out = self.up_tr64(out, out32)
        logger.debug(f"out: {out.shape}, out16: {out16.shape}")
        # Version 1:
        # out = self.up_tr32(out.to(devices[3]), out16.to(devices[3]))

        # Version 2: upsample + concat, LUConv stack, then residual add + activation.
        out_uptr32 = self.up_tr32(out, out16)
        out_uptr32_ops_val = self.up_tr32_ops(out_uptr32)
        out = self.up_tr32_relu2(torch.add(out_uptr32_ops_val, out_uptr32))

        logger.debug(f"out: {out.shape}")
        out = self.out_tr(out)
        logger.debug(f"out: {out.shape}")
        return out


# (2) data generator

In [None]:
class SliceAndResize:
    """Static helpers that turn a raw (image, label) volume pair into model inputs."""

    @staticmethod
    def mask2onehot(mask: np.ndarray, num_classes: int):
        """
        One-hot encode an integer label array, putting the class axis FIRST.

        Args:
            mask: array of integer class indices, any shape ``S``.
            num_classes: number of classes to encode; values outside
                ``[0, num_classes)`` get an all-zero vector.
        Returns:
            ``np.uint8`` array of shape ``(num_classes, *S)`` (for
            ``num_classes >= 1``) with ``out[c] == (mask == c)``.
        """
        _mask = [mask == i for i in range(num_classes)]
        return np.array(_mask).astype(np.uint8)

    @staticmethod
    def resize_images(img: np.ndarray, resize_ratio=None):
        """
        Nearest-neighbour zoom of ``img`` by ``resize_ratio`` (one factor per axis).

        ``order=0`` keeps label values discrete. Defaults to halving every axis.
        """
        if resize_ratio is None:
            resize_ratio = [0.5] * len(img.shape)
        logger.debug(f"resize ratio: {resize_ratio}")
        img = zoom(img, resize_ratio, order=0, mode='nearest')
        return img

    @staticmethod
    def slice_and_resize(x: np.ndarray, y: np.ndarray, slice_interval: list, resize_shape: list, num_classes: int):
        """
        Turn an (image, label) pair into ``(x[None], one_hot(y))``.

        NOTE: slicing along axis 0 and resizing are currently disabled.
        ``slice_interval`` and ``resize_shape`` are only logged and do not change
        the output, so the volumes are used at their stored size.

        Returns:
            (output_x, output_y): the image with a leading channel axis of size 1,
            and the uint8 one-hot label of shape ``(num_classes, *y.shape)``.

        Raises:
            ValueError: if ``x`` and ``y`` have different shapes.
        """
        logger.debug(f"x: {x.shape}, y: {y.shape}, slice_interval: {slice_interval}, resize_shape: {resize_shape}, num_classes: {num_classes}")
        # Explicit check (unlike ``assert``) is not stripped under ``python -O``.
        if x.shape != y.shape:
            raise ValueError("The dimensions of x and y should be the same")
        output_x = np.expand_dims(x, axis=0)
        output_y = SliceAndResize.mask2onehot(y, num_classes)
        logger.debug(f'slice x: {x.shape}, slice y: {y.shape}, resize_x: {x.shape}, resize_y: {y.shape}, output x: {output_x.shape}, output_y: {output_y.shape}')
        return output_x, output_y


In [None]:
class DataGenerator(Dataset):
    """
    Map-style dataset that loads one subject per index from ``.npy`` files.

    For subject id ``ids`` it loads two files with ``np.load``:
    ``<file_path>/<ids>/<x_prefix>.<x_postfix>`` (image) and
    ``<file_path>/<ids>/<y_prefix>.<y_postfix>`` (label). Both go through
    ``SliceAndResize.slice_and_resize``, which yields a channel-first image and
    a one-hot label volume. The label file is loaded even when ``to_fit`` is
    False.
    """

    def __init__(self,
                 list_IDs,
                 file_path="./data/",
                 num_classes=3,
                 resize_shape=(208, 96, 112),
                 slice_interval=(220, 428),
                 to_fit=True,
                 x_prefix: str = 'fat',
                 y_prefix: str = 'fat_label',
                 x_postfix: str = "npy",
                 y_postfix: str = "npy"
                 ):
        """
        :param list_IDs: subject ids; each names a sub-directory of ``file_path``
        :param file_path: root data directory
        :param num_classes: number of classes for the one-hot label encoding
        :param resize_shape: target volume shape forwarded to ``slice_and_resize``
        :param slice_interval: slice range forwarded to ``slice_and_resize``
        :param to_fit: if True, items are ``(ids, x, y)``; if False, items are only ``x``
        :param x_prefix: image file name (without extension)
        :param y_prefix: label file name (without extension)
        :param x_postfix: image file extension
        :param y_postfix: label file extension
        """
        self.list_IDs = list_IDs
        # Copy so instances never share (and mutate) the default or a caller's list.
        self.resize_shape = list(resize_shape)
        self.slice_interval = list(slice_interval)
        self.num_classes = num_classes
        self.file_path = file_path
        self.to_fit = to_fit
        self.x_prefix = x_prefix
        self.y_prefix = y_prefix
        self.x_postfix = x_postfix
        self.y_postfix = y_postfix

    def __len__(self):
        """Number of subjects."""
        return len(self.list_IDs)

    def __generate_x_y_data(self, index):
        """
        Load and preprocess the subject at ``index``.

        :param index: position in ``list_IDs``
        :return: ``(ids, x, y)`` if ``to_fit`` else ``(ids, x)``
        """
        ids = self.list_IDs[index]
        subject_dir = os.path.join(self.file_path, ids)
        x_path = os.path.join(subject_dir, f"{self.x_prefix}.{self.x_postfix}")
        y_path = os.path.join(subject_dir, f"{self.y_prefix}.{self.y_postfix}")
        x = np.load(x_path)
        y = np.load(y_path)
        logger.debug(f'ids: {ids}, origin x shape: {x.shape}, {type(x)}， origin y shape: {y.shape}, {type(y)}')
        x, y = SliceAndResize.slice_and_resize(x=x, y=y, slice_interval=self.slice_interval, resize_shape=self.resize_shape, num_classes=self.num_classes)

        if self.to_fit:
            logger.debug(f'train status：{self.to_fit}, ids: {ids}, x shape: {x.shape}， y shape: {y.shape}')
            return ids, x, y
        else:
            logger.debug(f'test status：{self.to_fit}, ids: {ids}, x shape: {x.shape}')
            return ids, x

    def __getitem__(self, item):
        """
        :param item: example index
        :return: ``(ids, x, y)`` when ``to_fit`` is True, otherwise only ``x``
        """
        logger.debug(f'example index: {item}')
        if self.to_fit:
            logger.debug('begin generate x and y')
            ids, x, y = self.__generate_x_y_data(item)
            return ids, x, y
        else:
            ids, x = self.__generate_x_y_data(item)
            return x


# (3) eval metrics

In [None]:
class EvalMetric:
    """
    Dice and Tversky metrics for one-hot segmentation tensors.

    Tensors have classes on axis 1, i.e. ``(batch, classes, *spatial)``. The
    VNet outputs in this notebook are 5-D, but ``dice_coef_multilabel`` and
    ``tversky_loss`` accept any rank >= 2.

    Args:
        dim: spatial sample shape (stored only; not used by the metrics).
        batch_size: batch size (stored only; not used by the metrics).
        n_classes: number of classes on axis 1, background (index 0) included.
        smooth: additive smoothing term in the Dice ratio.
    """

    def __init__(self, dim, batch_size, n_classes, smooth=1):
        self.dim = dim
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.smooth = smooth

    def dice_coef(self, y_true, y_pred):
        """
        Smoothed Dice coefficient over all elements, with everything flattened together.

        Computes ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` in float32.
        """
        assert len(y_true.shape) == len(y_pred.shape), Exception("y_Ture and Y_PRED has different dimensions")

        y_true, y_pred = y_true.type(torch.float32), y_pred.type(torch.float32)
        logger.debug(f"y_true: {y_true.shape}, y_pred: {y_pred.shape}")

        y_true = torch.flatten(y_true)
        y_pred = torch.flatten(y_pred)
        logger.debug(f"y_true: {y_true.shape}, y_pred: {y_pred.shape}")

        intersection = torch.sum(y_true * y_pred)
        val = (2. * intersection + self.smooth) / (torch.sum(y_true) + torch.sum(y_pred) + self.smooth)
        logger.debug(f"intersection: {intersection}, val: {val}")
        return val

    def dice_coef_multilabel(self, y_true, y_pred):
        """
        Mean Dice over the foreground classes ``1 .. n_classes-1`` (background excluded).

        Raises:
            ValueError: if ``n_classes < 2`` (there is no foreground class to average).
        """
        if self.n_classes < 2:
            raise ValueError(
                f"dice_coef_multilabel needs n_classes >= 2 (background + foreground), got {self.n_classes}"
            )
        dice = 0
        for index in range(1, self.n_classes):
            # ``[:, index]`` picks one class for every sample, whatever the spatial rank.
            dice += self.dice_coef(y_true[:, index], y_pred[:, index])
        return dice / (self.n_classes - 1)

    def dice_coef_loss(self, y_true, y_pred):
        """``1 - dice_coef`` over all classes, background included."""
        return 1 - self.dice_coef(y_true, y_pred)

    def tversky_loss(self, y_true, y_pred, alpha=0.3, beta=0.7):
        """
        Multi-class Tversky loss.

        Per class ``c``: ``T_c = TP_c / (TP_c + alpha * FP_c + beta * FN_c)``.
        The soft counts are summed over the batch and all spatial axes. The loss
        is ``classes - sum_c T_c``.

        Args:
            y_true: one-hot targets ``(batch, classes, *spatial)``.
            y_pred: class probabilities of the same shape.
            alpha: false-positive weight (default 0.3).
            beta: false-negative weight (default 0.7).
        Returns:
            0-dim tensor in ``[0, classes]``; 0 for a perfect prediction.
        """
        y_true, y_pred = y_true.type(torch.float32), y_pred.type(torch.float32)

        p0 = y_pred  # proba that voxels are class i
        p1 = 1 - y_pred  # proba that voxels are not class i
        g0 = y_true
        g1 = 1 - y_true

        # Reduce over every axis except the class axis (1); (0, 2, 3, 4) for 5-D input.
        reduce_dims = (0,) + tuple(range(2, y_true.dim()))
        num = torch.sum(p0 * g0, reduce_dims)
        den = num + alpha * torch.sum(p0 * g1, reduce_dims) + beta * torch.sum(p1 * g0, reduce_dims)

        T = torch.sum(num / den)  # when summing over classes, T has dynamic range [0 Ncl]

        Ncl = float(y_true.shape[1])
        return Ncl - T


# (4) 这里是数据参数，需要修改

# --- water

In [None]:
import os
s3_path = "/dbfs/mnt/hli-imaging-sdrad-pdx/Whole_Body_Composition_v2/labeled_data_finalized_1/"
# Subject ids are the names of the sub-directories of s3_path. The context
# manager closes the scandir iterator even if listing fails part-way.
with os.scandir(s3_path) as entries:
    ids = [entry.name for entry in entries if os.path.isdir(entry.path)]
print(len(ids))

900


In [None]:
# Held-out validation subjects, evaluated after every training epoch.
male_val_id = ['BJ00000075', 'BJ00000060', 'BJ00000054', 'BJ00000050', 'BJ00000096']
female_val_id = ['BJ00000071', 'BJ00000150', 'BJ00000077', 'BJ00000033', 'BJ00000120']
test_ids = [*male_val_id, *female_val_id]

# A second held-out set, also excluded from training.
# Earlier, larger version of this set:
#   male:   'BJ00000016', 'BJ00000125', 'BJ00000147', 'BJ00000044', 'BJ00000114'
#   female: 'BJ00000100', 'BJ00000090', 'BJ00000110', 'BJ00000108', 'BJ00000002'
male_test_id = ['BJ00000016', 'BJ00000044', 'BJ00000114']
female_test_id = ['BJ00000100', 'BJ00000002']
test_ids_2 = [*male_test_id, *female_test_id]

In [None]:
# FAT  ->   VAT/ASAT  (fat images, 3 classes including background)

log_level = "INFO"
# log_level = 'DEBUG'
logger.remove()
handler_id = logger.add(sys.stderr, level=log_level)
epochs = 100
batch_size = 1
lr = 0.001

dim = [240, 256, 320]
slice_interval = [120, 360]
n_classes = 3
x_prefix = 'fat'
y_prefix = 'fat_label'
x_postfix = "npy"
y_postfix = "npy"


data_path = '/dbfs/mnt/hli-imaging-sdrad-pdx/Whole_Body_Composition_v2/labeled_data_finalized_1'

model_path = './TEST/'
os.makedirs(model_path, exist_ok=True)

# Training ids: every 'BJ...' id that is in neither held-out set. For ids
# containing '_', also drop them when the second '_'-separated component is a
# held-out id.
held_out_ids = set(test_ids) | set(test_ids_2)  # built once, not per id
filter_1 = [x for x in set(ids) if (('_' not in x) or (x.split('_')[1] not in held_out_ids)) and x.startswith('BJ')]
# Sorted so the dataset's index -> subject mapping does not depend on set/hash order.
train_ids = sorted(set(filter_1) - held_out_ids)
logger.info(f"train ids: {len(train_ids)}, test ids: {len(test_ids)}")

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [None]:
# water  ->   muscle  (water images, 2 classes including background)

log_level = "INFO"
# log_level = 'DEBUG'
logger.remove()
handler_id = logger.add(sys.stderr, level=log_level)
epochs = 100
batch_size = 1
lr = 0.001

# Earlier crop / input-size experiments:
# slice_interval = [209, 489]
# slice_interval = [205, 493]
# dim = [240, 128, 176]
# dim = [240, 96, 112]
# dim = [288, 224+32, 288+32]      # VNet maximum size
# dim = [288, 224-16, 288-16-16]   # [1, 0.8, 0.8] scaling, "4*16g" (presumably 4 GPUs x 16 GB)
# dim = [288, 128, 176]            # VNet size
# dim = [280, 200, 264]            # UNet maximum size

# Depth of dim equals the slice_interval length (501 - 197 = 304).
dim = [304, 256, 320]
slice_interval = [197, 501]
n_classes = 2
data_path = '/dbfs/mnt/hli-imaging-sdrad-pdx/Whole_Body_Composition_v2/labeled_data_finalized_1'

model_path = './TEST/'
os.makedirs(model_path, exist_ok=True)
# Pre-cropped arrays stored as '<prefix>_<depth>_<height>_<width>.npy'.
x_prefix = 'water_304_256_320'
y_prefix = 'water_label_304_256_320'
x_postfix = "npy"
y_postfix = "npy"

# Training ids: every 'BJ...' id that is in neither held-out set. For ids
# containing '_', also drop them when the second '_'-separated component is a
# held-out id.
held_out_ids = set(test_ids) | set(test_ids_2)  # built once, not per id
filter_1 = [x for x in set(ids) if (('_' not in x) or (x.split('_')[1] not in held_out_ids)) and x.startswith('BJ')]
# Sorted so the dataset's index -> subject mapping does not depend on set/hash order.
train_ids = sorted(set(filter_1) - held_out_ids)
logger.info(f"train ids: {len(train_ids)}, test ids: {len(test_ids)}")

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


2022-11-09 03:35:55.007 | INFO     | __main__:<cell line: 45>:45 - train ids: 135, test ids: 10


# (5) model training

In [None]:
# Both generators read the same files with the same preprocessing; only the ids differ.
_shared_generator_kwargs = dict(
    file_path=data_path,
    num_classes=n_classes,
    resize_shape=dim,
    slice_interval=slice_interval,
    x_prefix=x_prefix,
    y_prefix=y_prefix,
    x_postfix=x_postfix,
    y_postfix=y_postfix,
)
trainset = DataGenerator(list_IDs=train_ids, **_shared_generator_kwargs)
testset = DataGenerator(list_IDs=test_ids, **_shared_generator_kwargs)

train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size)

evalmetric = EvalMetric(dim, batch_size, n_classes)


In [None]:
# Instantiate whichever VNet_Parallelism cell was executed last in this notebook.
model = VNet_Parallelism(in_channels=1, classes=n_classes)
# To resume from a checkpoint written by the training loop:
# model.load_state_dict(torch.load('./TEST/epoch5_trainloss_0.066_validacc_0.937.pth'))

# Single-device alternative (VNet's keyword is `classes`, not `out_channels`):
# model = VNet(in_channels=1, classes=n_classes)
# model = model.to(device)

2022-11-09 03:26:11.686 | INFO     | __main__:__init__:12 - begin initialize model struction


# --- 仅使用模型并行+单精度的VNET版本

In [None]:
def evaluate(model, test_loader, metric=None):
    """Run one validation pass; return per-batch Dice, mean Dice and mean loss.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, so calling this inside a training
    loop does not leave BatchNorm/Dropout in inference mode for later epochs.

    Args:
        model: module; inputs are sent to the device of its first parameter
            (the entry device for the model-parallel VNet).
        test_loader: iterable of ``(ids, x, y)`` batches that supports ``len()``.
        metric: object providing ``tversky_loss(y, p)`` and
            ``dice_coef_multilabel(y, p)``; defaults to the notebook's ``evalmetric``.

    Returns:
        (acc_list, mean_acc, val_loss): per-batch foreground Dice as floats,
        their mean, and the Tversky loss averaged over ``len(test_loader)``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    if metric is None:
        metric = evalmetric
    was_training = model.training
    input_device = next(model.parameters()).device
    acc_list = []
    losses = 0.0
    model.eval()
    try:
        with torch.no_grad():
            for _ids, x, y in test_loader:
                logits = model(x.to(input_device))
                probs = F.softmax(logits, dim=1)
                y = y.to(probs.device)
                # float() moves scalars to the host so no device tensors are retained.
                losses += float(metric.tversky_loss(y, probs))
                acc_list.append(float(metric.dice_coef_multilabel(y, probs)))
    finally:
        model.train(was_training)
    if not acc_list:
        raise ValueError("test_loader yielded no batches")
    val_loss = losses / len(test_loader)
    return acc_list, sum(acc_list) / len(acc_list), val_loss
# Train with the Tversky loss; the foreground Dice from dice_coef_multilabel is logged as "acc".

max_acc = 0  # best validation Dice so far; gates checkpointing
# Alternative optimizer: torch.optim.ASGD(model.parameters(), lr=lr)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(epochs):
    # evaluate() switches the model to eval mode, so re-enable training mode at
    # the start of every epoch. Otherwise BatchNorm/Dropout would run in
    # inference mode from the second epoch on.
    model.train()
    losses = 0
    start_time = time.time()
    for i, (idx, x, y) in enumerate(train_loader):
        # With model parallelism the input goes to the device of the first parameter.
        logits = model(x.to(next(model.parameters()).device))
        logits = F.softmax(logits, dim=1)
        y = y.to(logits.device)
        loss = evalmetric.tversky_loss(y, logits)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses += loss.item()

        # Dice is only reported, so keep it out of the autograd graph.
        with torch.no_grad():
            acc = evalmetric.dice_coef_multilabel(y, logits)

        logger.info(f"Epoch: {epoch}, idx:{idx[0]}, Batch[{i+1}/{len(train_loader)}],"
                    f" Train loss:{loss.item():.3f}, Train acc:{acc: .3f}")

    acc_list, mean_acc, val_loss = evaluate(model, test_loader)
    end_time = time.time()
    train_loss = losses / len(train_loader)
    if mean_acc > max_acc:
        # Checkpoint only on a new best validation Dice.
        max_acc = mean_acc
        torch.save(model.state_dict(), f"{model_path}epoch{epoch+1}_trainloss_{train_loss:.3f}_validacc_{mean_acc:.3f}.pth")
    logger.info(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Epoch time={(end_time - start_time):.3f}s, valid mean accuracy: {mean_acc:.3f}, valid loss: {val_loss}")


2022-11-03 06:44:53.185 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000107, Batch[1/130], Train loss:1.435, Train acc: 0.108
2022-11-03 06:45:06.727 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000132, Batch[2/130], Train loss:1.097, Train acc: 0.269
2022-11-03 06:45:20.267 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000039, Batch[3/130], Train loss:1.228, Train acc: 0.151
2022-11-03 06:45:33.807 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000103, Batch[4/130], Train loss:1.126, Train acc: 0.222
2022-11-03 06:45:47.394 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000137, Batch[5/130], Train loss:1.160, Train acc: 0.198
2022-11-03 06:46:00.932 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000027, Batch[6/130], Train loss:1.173, Train acc: 0.188
2022-11-03 06:46:14.477 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000029, Batch[7/130], Train loss:1.128, Train acc: 0.215
2022-1

In [None]:
def evaluate(model, test_loader, metric=None):
    """Run the model over `test_loader` without gradients and score every batch.

    Args:
        model: network whose first parameter's device receives the inputs.
        test_loader: iterable yielding ``(idx, x, y)`` batches.
        metric: object providing ``tversky_loss(y, probs)`` and
            ``dice_coef_multilabel(y, probs)``. Defaults to the notebook-level
            ``evalmetric``; exposed so the function can be reused/tested.

    Returns:
        ``(acc_list, mean_acc, val_loss)``: per-batch Dice scores, their mean and
        the mean Tversky loss as a Python float.

    Raises:
        ValueError: if `test_loader` yields no batches.

    The model's train/eval mode is restored on exit, so a training loop that
    calls this once per epoch keeps training in train mode.
    """
    if metric is None:
        metric = evalmetric
    was_training = model.training
    device = next(model.parameters()).device
    acc_list = []
    total_loss = 0.0
    model.eval()
    try:
        with torch.no_grad():
            for idx, x, y in test_loader:
                logits = F.softmax(model(x.to(device)), dim=1)
                y = y.to(logits.device)
                # float() turns the scalar loss into a plain number (no tensor in the log).
                total_loss += float(metric.tversky_loss(y, logits))
                acc_list.append(metric.dice_coef_multilabel(y, logits))
    finally:
        # Previously the model was left in eval mode, silently affecting later training.
        model.train(was_training)
    if not acc_list:
        raise ValueError("test_loader yielded no batches; cannot compute validation metrics")
    mean_acc = sum(acc_list) / len(acc_list)
    val_loss = total_loss / len(acc_list)
    return acc_list, mean_acc, val_loss
# Training loop. The loss is the Tversky loss from `evalmetric` (not cross-entropy).
# A checkpoint is saved only when the validation mean Dice beats the best seen so far.

max_acc = 0
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(epochs):
    # `evaluate` may leave the model in eval mode; switch back every epoch so training
    # always runs with layers such as dropout/normalization in training mode.
    model.train()
    losses = 0
    start_time = time.time()
    # Inputs go to the device of the model's first parameter (presumably the first
    # stage of the model-parallel network).
    input_device = next(model.parameters()).device
    for i, (idx, x, y) in enumerate(train_loader):
        logits = F.softmax(model(x.to(input_device)), dim=1)
        y = y.to(logits.device)
        loss = evalmetric.tversky_loss(y, logits)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        losses += batch_loss

        # Training Dice is only logged; compute it outside autograd so no graph is built.
        with torch.no_grad():
            acc = evalmetric.dice_coef_multilabel(y, logits)

        logger.info(f"Epoch: {epoch}, idx:{idx[0]}, Batch[{i+1}/{len(train_loader)}],"
                    f" Train loss:{batch_loss:.3f}, Train acc:{acc: .3f}")

    acc_list, mean_acc, val_loss = evaluate(model, test_loader)
    end_time = time.time()
    train_loss = losses / len(train_loader)
    if mean_acc > max_acc:
        # Remember the best score; without this every epoch with mean_acc > 0 was saved.
        max_acc = float(mean_acc)
        torch.save(model.state_dict(), f"{model_path}epoch{epoch+1}_trainloss_{train_loss:.3f}_validacc_{mean_acc:.3f}.pth")
    logger.info(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Epoch time={(end_time - start_time):.3f}s, valid mean accuracy: {mean_acc:.3f}, valid loss: {val_loss}")


2022-11-08 06:36:13.012 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000001, Batch[1/135], Train loss:1.366, Train acc: 0.071
2022-11-08 06:36:27.310 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000055, Batch[2/135], Train loss:1.266, Train acc: 0.107
2022-11-08 06:36:41.587 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000006, Batch[3/135], Train loss:1.206, Train acc: 0.136
2022-11-08 06:36:55.910 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000019, Batch[4/135], Train loss:1.127, Train acc: 0.195
2022-11-08 06:37:10.230 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000065, Batch[5/135], Train loss:1.122, Train acc: 0.198
2022-11-08 06:37:24.526 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000038, Batch[6/135], Train loss:1.214, Train acc: 0.128
2022-11-08 06:37:38.816 | INFO     | __main__:<cell line: 25>:60 - Epoch: 0, idx:BJ00000141, Batch[7/135], Train loss:1.193, Train acc: 0.143
2022-1

In [None]:
%sh
# List the checkpoints written by the training loop
# (files named epoch<N>_trainloss_<x>_validacc_<y>.pth).
ls ./TEST/

epoch10_trainloss_0.051_validacc_0.946.pth
epoch11_trainloss_0.052_validacc_0.944.pth
epoch12_trainloss_0.049_validacc_0.949.pth
epoch13_trainloss_0.048_validacc_0.944.pth
epoch14_trainloss_0.046_validacc_0.951.pth
epoch15_trainloss_0.047_validacc_0.941.pth
epoch16_trainloss_0.046_validacc_0.953.pth
epoch17_trainloss_0.056_validacc_0.940.pth
epoch18_trainloss_0.045_validacc_0.947.pth
epoch19_trainloss_0.044_validacc_0.954.pth
epoch1_trainloss_1.046_validacc_0.243.pth
epoch20_trainloss_0.043_validacc_0.951.pth
epoch21_trainloss_0.041_validacc_0.955.pth
epoch22_trainloss_0.041_validacc_0.956.pth
epoch23_trainloss_0.040_validacc_0.956.pth
epoch24_trainloss_0.040_validacc_0.951.pth
epoch25_trainloss_0.041_validacc_0.951.pth
epoch26_trainloss_0.040_validacc_0.951.pth
epoch27_trainloss_0.038_validacc_0.954.pth
epoch28_trainloss_0.038_validacc_0.958.pth
epoch29_trainloss_0.038_validacc_0.959.pth
epoch2_trainloss_0.382_validacc_0.668.pth
epoch30_trainloss_0.037_validacc_0.958.pth
epoch31_train

# (6) Save weights to S3

In [None]:
%sh
# Remove an obsolete CPU copy of the epoch-29 checkpoint.
# -f keeps the cell re-runnable: the file is already gone after the first run
# (it is absent from the listing below), and plain `rm` would then fail.
rm -f ./TEST/cpu_epoch29_trainloss_0.038_validacc_0.959.pth

In [None]:
%sh
# Re-list the checkpoint directory to confirm the cleanup and see the cpu_* copies.
ls ./TEST/

cpu_epoch34_trainloss_0.037_validacc_0.959.pth
cpu_epoch35_trainloss_0.037_validacc_0.961.pth
epoch10_trainloss_0.051_validacc_0.946.pth
epoch11_trainloss_0.052_validacc_0.944.pth
epoch12_trainloss_0.049_validacc_0.949.pth
epoch13_trainloss_0.048_validacc_0.944.pth
epoch14_trainloss_0.046_validacc_0.951.pth
epoch15_trainloss_0.047_validacc_0.941.pth
epoch16_trainloss_0.046_validacc_0.953.pth
epoch17_trainloss_0.056_validacc_0.940.pth
epoch18_trainloss_0.045_validacc_0.947.pth
epoch19_trainloss_0.044_validacc_0.954.pth
epoch1_trainloss_1.046_validacc_0.243.pth
epoch20_trainloss_0.043_validacc_0.951.pth
epoch21_trainloss_0.041_validacc_0.955.pth
epoch22_trainloss_0.041_validacc_0.956.pth
epoch23_trainloss_0.040_validacc_0.956.pth
epoch24_trainloss_0.040_validacc_0.951.pth
epoch25_trainloss_0.041_validacc_0.951.pth
epoch26_trainloss_0.040_validacc_0.951.pth
epoch27_trainloss_0.038_validacc_0.954.pth
epoch28_trainloss_0.038_validacc_0.958.pth
epoch29_trainloss_0.038_validacc_0.959.pth
epoc

In [None]:
# Restore the cpu_epoch35 checkpoint into `model`. The cpu_* files were presumably
# produced with torch.save(model.cpu().state_dict(), ...) from the GPU checkpoints.
# load_state_dict returns the missing/unexpected-key report (not a state dict);
# with the default strict=True a key mismatch raises instead.
best_checkpoint = './TEST/cpu_epoch35_trainloss_0.037_validacc_0.961.pth'
model_dict = model.load_state_dict(torch.load(best_checkpoint))

In [None]:
%sh
# Upload everything in /databricks/driver/TEST (all saved checkpoints, ~6.5 GiB per the
# transfer log) to the water-model prefix for this run (135 cases, 304x256x320).
# --sse requests server-side encryption; the ACL grants the bucket owner full control.
/databricks/python/bin/aws s3 cp --recursive /databricks/driver/TEST s3://hli-imaging-sdrad-pdx/Whole_Body_Composition_v2/trained_model_water/11072022_vnet_135_cases_304_256_320/ --sse --acl bucket-owner-full-control

Completed 256.0 KiB/6.5 GiB (1.8 MiB/s) with 38 file(s) remaining
Completed 512.0 KiB/6.5 GiB (3.5 MiB/s) with 38 file(s) remaining
Completed 768.0 KiB/6.5 GiB (5.1 MiB/s) with 38 file(s) remaining
Completed 1.0 MiB/6.5 GiB (6.6 MiB/s) with 38 file(s) remaining  
Completed 1.2 MiB/6.5 GiB (8.2 MiB/s) with 38 file(s) remaining  
Completed 1.5 MiB/6.5 GiB (9.8 MiB/s) with 38 file(s) remaining  
Completed 1.8 MiB/6.5 GiB (11.5 MiB/s) with 38 file(s) remaining 
Completed 2.0 MiB/6.5 GiB (13.0 MiB/s) with 38 file(s) remaining 
Completed 2.2 MiB/6.5 GiB (14.4 MiB/s) with 38 file(s) remaining 
Completed 2.5 MiB/6.5 GiB (15.9 MiB/s) with 38 file(s) remaining 
Completed 2.8 MiB/6.5 GiB (17.4 MiB/s) with 38 file(s) remaining 
Completed 3.0 MiB/6.5 GiB (18.9 MiB/s) with 38 file(s) remaining 
Completed 3.2 MiB/6.5 GiB (20.4 MiB/s) with 38 file(s) remaining 
Completed 3.5 MiB/6.5 GiB (21.7 MiB/s) with 38 file(s) remaining 
Completed 3.8 MiB/6.5 GiB (23.2 MiB/s) with 38 file(s) remaining 
Completed 

In [None]:
%sh
# Upload a single .hdf5 weights file, under a more descriptive name, to the
# fat-model prefix (run 11022022, 130 cases, 240x256x320).
/databricks/python/bin/aws s3 cp /databricks/driver/TEST/weights.024-0.15-0.9238.hdf5 s3://hli-imaging-sdrad-pdx/Whole_Body_Composition_v2/trained_model_fat/11022022_vnet_130_cases_240_256_320/409_10_slicing_240_176_224_0_32_weights.024-0.15-0.9238.hdf5 --sse --acl bucket-owner-full-control

--- Copy weights from S3

In [None]:
%sh
# Download the epoch-5 checkpoint of an earlier water-model run
# (11022022, 130 cases, 288x256x320) into ./TEST.
# NOTE(review): --sse/--acl describe how objects are stored on upload; they are
# presumably ignored for a download — confirm and drop them.
/databricks/python/bin/aws s3 cp s3://hli-imaging-sdrad-pdx/Whole_Body_Composition_v2/trained_model_water/11022022_vnet_130_cases_288_256_320/epoch5_trainloss_0.066_validacc_0.937.pth /databricks/driver/TEST/epoch5_trainloss_0.066_validacc_0.937.pth --sse --acl bucket-owner-full-control

Completed 256.0 KiB/174.0 MiB (1.2 MiB/s) with 1 file(s) remaining
Completed 512.0 KiB/174.0 MiB (2.3 MiB/s) with 1 file(s) remaining
Completed 768.0 KiB/174.0 MiB (3.5 MiB/s) with 1 file(s) remaining
Completed 1.0 MiB/174.0 MiB (4.6 MiB/s) with 1 file(s) remaining  
Completed 1.2 MiB/174.0 MiB (5.7 MiB/s) with 1 file(s) remaining  
Completed 1.5 MiB/174.0 MiB (6.8 MiB/s) with 1 file(s) remaining  
Completed 1.8 MiB/174.0 MiB (7.9 MiB/s) with 1 file(s) remaining  
Completed 2.0 MiB/174.0 MiB (8.9 MiB/s) with 1 file(s) remaining  
Completed 2.2 MiB/174.0 MiB (10.0 MiB/s) with 1 file(s) remaining 
Completed 2.5 MiB/174.0 MiB (10.9 MiB/s) with 1 file(s) remaining 
Completed 2.8 MiB/174.0 MiB (12.0 MiB/s) with 1 file(s) remaining 
Completed 3.0 MiB/174.0 MiB (12.8 MiB/s) with 1 file(s) remaining 
Completed 3.2 MiB/174.0 MiB (13.9 MiB/s) with 1 file(s) remaining 
Completed 3.5 MiB/174.0 MiB (14.8 MiB/s) with 1 file(s) remaining 
Completed 3.8 MiB/174.0 MiB (15.7 MiB/s) with 1 file(s) remain

# (7) Predict cases

In [None]:
%sh
# Show which checkpoints are available before choosing one for prediction.
ls ./TEST/

cpu_epoch34_trainloss_0.037_validacc_0.959.pth
cpu_epoch35_trainloss_0.037_validacc_0.961.pth
epoch10_trainloss_0.051_validacc_0.946.pth
epoch11_trainloss_0.052_validacc_0.944.pth
epoch12_trainloss_0.049_validacc_0.949.pth
epoch13_trainloss_0.048_validacc_0.944.pth
epoch14_trainloss_0.046_validacc_0.951.pth
epoch15_trainloss_0.047_validacc_0.941.pth
epoch16_trainloss_0.046_validacc_0.953.pth
epoch17_trainloss_0.056_validacc_0.940.pth
epoch18_trainloss_0.045_validacc_0.947.pth
epoch19_trainloss_0.044_validacc_0.954.pth
epoch1_trainloss_1.046_validacc_0.243.pth
epoch20_trainloss_0.043_validacc_0.951.pth
epoch21_trainloss_0.041_validacc_0.955.pth
epoch22_trainloss_0.041_validacc_0.956.pth
epoch23_trainloss_0.040_validacc_0.956.pth
epoch24_trainloss_0.040_validacc_0.951.pth
epoch25_trainloss_0.041_validacc_0.951.pth
epoch26_trainloss_0.040_validacc_0.951.pth
epoch27_trainloss_0.038_validacc_0.954.pth
epoch28_trainloss_0.038_validacc_0.958.pth
epoch29_trainloss_0.038_validacc_0.959.pth
epoc

In [None]:
# Inference configuration.
# The label slab [197, 501) along axis 0 is 304 slices, matching dim[0].
batch_size, n_classes = 1, 2
dim = [304, 256, 320]
slice_interval = [197, 501]

# Fresh model instance with the cpu_epoch35 weights (strict key matching).
model = VNet_Parallelism(in_channels=1, classes=n_classes)
model_dict = model.load_state_dict(
    torch.load('./TEST/cpu_epoch35_trainloss_0.037_validacc_0.961.pth')
)

2022-11-09 03:43:07.360 | INFO     | __main__:__init__:10 - begin initialize model struction


In [None]:
from termcolor import *

# Additive smoothing used by dice_coef; keeps the score defined (== 1) for two empty masks.
smooth = 1.
def dice_coef(y_true, y_pred):
    """Return the smoothed Dice coefficient of two same-sized numpy arrays.

    Both inputs are flattened and scored as
    (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth).
    """
    truth = y_true.flatten()
    guess = y_pred.flatten()
    overlap = (truth * guess).sum()
    return (2. * overlap + smooth) / (truth.sum() + guess.sum() + smooth)

# Maps each raw label value to its one-hot channel index.
class_map = {0: 0, 1: 1}
def one_hot_encoding(category):
    """One-hot encode an integer label volume.

    Args:
        category: array-like of label values, any shape. The mask shape used to be
            hard-coded to (304, 256, 320), which crashed for any other volume size.

    Returns:
        float32 array of shape ``category.shape + (n_channels,)`` with
        ``n_channels = max(class_map.values()) + 1``. Values that are not keys of
        ``class_map`` fall back to channel 0.
    """
    category = np.asarray(category)
    n_channels = max(class_map.values()) + 1
    class_mask = np.zeros(category.shape + (n_channels,), dtype=np.float32)
    # Default every voxel to channel 0 so unmapped values still get a valid one-hot row.
    class_mask[..., 0] = 1
    for mask_value, channel in class_map.items():
        one_hot = np.zeros(n_channels)
        one_hot[channel] = 1
        class_mask[category == mask_value] = one_hot
    return class_mask

def numerical_analysis(y_true, y_pred):
    """Print foreground statistics for two one-hot volumes.

    Uses channel 1 on the fourth axis of each array and prints the predicted,
    true and overlapping voxel counts, the unsmoothed Dice, then the false-negative
    and false-positive counts (labels coloured red via termcolor).
    """
    fg_pred = y_pred[:, :, :, 1]
    fg_true = y_true[:, :, :, 1]
    pred = np.sum(fg_pred)
    true = np.sum(fg_true)
    inter = np.sum(np.ndarray.flatten(fg_pred) * np.ndarray.flatten(fg_true))
    dice = 2 * inter / (pred + true)
    print('prediction: ', pred, ' true: ', true, 'intersection: ', inter, 'dice: ', dice)
    missed = true - inter
    spurious = pred - inter
    print(colored('--- false negative: ', 'red'), missed, colored('--- false positive: ', 'red'), spurious, '\n')


In [None]:
# Validation subjects: 5 male + 5 female.
male_val_id = ['BJ00000075', 'BJ00000060', 'BJ00000054', 'BJ00000050', 'BJ00000096']
female_val_id = ['BJ00000071', 'BJ00000150', 'BJ00000077', 'BJ00000033', 'BJ00000120']
test_ids = [*male_val_id, *female_val_id]

# Test subjects. The original cohort also listed BJ00000125, BJ00000147 (male) and
# BJ00000090, BJ00000110, BJ00000108 (female); those are currently excluded.
male_test_id = ['BJ00000016', 'BJ00000044', 'BJ00000114']
female_test_id = ['BJ00000100', 'BJ00000002']
test_ids_2 = [*male_test_id, *female_test_id]

# Every case evaluated by the prediction cells: validation first, then test.
test_ids_all = [*test_ids, *test_ids_2]

In [None]:
# Evaluation dataset over every validation + test subject; one case per batch.
# (resize_shape / slice_interval semantics are defined by DataGenerator.)
testset = DataGenerator(
    list_IDs=test_ids_all,
    file_path=data_path,
    num_classes=n_classes,
    resize_shape=dim,
    slice_interval=slice_interval,
    x_prefix=x_prefix,
    y_prefix=y_prefix,
    x_postfix=x_postfix,
    y_postfix=y_postfix,
)
test_loader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size)

# --- run all cases

# ------ cpu

In [None]:
# Evaluate checkpoint cpu_epoch35_trainloss_0.037_validacc_0.961.pth on every case in test_loader.
# Loading it here keeps the cell independent of whatever weights `model` currently holds.
model.load_state_dict(torch.load('./TEST/cpu_epoch35_trainloss_0.037_validacc_0.961.pth'))
model.eval()
with torch.no_grad():
    for idx, x, y in test_loader:
        # Only idx[0]'s label is read from disk below, so each batch must hold one case.
        assert len(idx) == 1, "prediction cells expect test_loader with batch_size=1"
        print(idx)
        print('[1] shape of fat and fat_label: ', x.shape, y.shape)

        # Step 1, predict. No autocast: `evaluate`, which produced the validation accuracy
        # in the checkpoint name, runs in full precision, and fp16 here can disagree.
        logits = model(x.to(next(model.parameters()).device))
        logits = F.softmax(logits, dim=1)
        print('[2] shape of logits: ', logits.shape)

        # Step 2, per-voxel class index; drop the batch axis.
        fat_argmax = logits.argmax(dim=1).cpu()
        print('[3] argmax of fat_prediction: ', fat_argmax.shape)
        fat_argmax = fat_argmax[0].numpy()
        print('[4] reshape of fat_argmax: ', fat_argmax.shape)

        # Step 3, ground truth: the `slice_interval` slab along axis 0, cropped by 2 voxels
        # on both sides of axis 1.
        input_y = np.load(data_path + '/' + idx[0] + '/water_label.npy')
        slice_y = input_y[slice_interval[0]:slice_interval[1], 2:-2, :]

        # Step 4, resize the prediction onto the label grid (order=0 keeps class indices
        # intact) rather than hard-coding the target size, then one-hot encode both.
        zoom_factors = [t / s for t, s in zip(slice_y.shape, fat_argmax.shape)]
        fat_argmax = zoom(fat_argmax, zoom_factors, order=0, mode='nearest')
        prediction_label_origin = one_hot_encoding(fat_argmax)
        print('[5] one-hot encoding of prediction: ', prediction_label_origin.shape)
        img_fat_label = one_hot_encoding(slice_y)
        print('[6] one-hot encoding of origin: ', img_fat_label.shape)

        # Step 5, foreground (channel 1) Dice and voxel-count breakdown.
        print('[7] dice of muscle: ', dice_coef(img_fat_label[:,:,:,1], prediction_label_origin[:,:,:,1]))
        print('[8] numerical analysis:')
        numerical_analysis(img_fat_label, prediction_label_origin)


('BJ00000075',)
[1] shape of fat and fat_label:  torch.Size([1, 1, 304, 256, 320]) torch.Size([1, 2, 304, 256, 320])
[2] shape of logits:  torch.Size([1, 2, 304, 256, 320])
[3] argmax of fat_prediction:  torch.Size([1, 304, 256, 320])
[4] reshape of fat_argmax:  torch.Size([304, 256, 320])
[5] one-hot encoding of prediction:  (304, 256, 320, 2)
[6] one-hot encoding of origin:  (304, 256, 320, 2)
[7] dice of muscle:  0.9701602630808448
[8] numerical analysis:
prediction:  2276439.0  true:  2235329.0 intersection:  2188569.0 dice:  0.9701602564670878
[31m--- false negative: [0m 46760.0 [31m--- false positive: [0m 87870.0 

('BJ00000060',)
[1] shape of fat and fat_label:  torch.Size([1, 1, 304, 256, 320]) torch.Size([1, 2, 304, 256, 320])
[2] shape of logits:  torch.Size([1, 2, 304, 256, 320])
[3] argmax of fat_prediction:  torch.Size([1, 304, 256, 320])
[4] reshape of fat_argmax:  torch.Size([304, 256, 320])
[5] one-hot encoding of prediction:  (304, 256, 320, 2)
[6] one-hot encoding

In [None]:
# Evaluate checkpoint cpu_epoch35_trainloss_0.037_validacc_0.961.pth on every case in test_loader.
# Loading it here keeps the cell independent of whatever weights `model` currently holds.
model.load_state_dict(torch.load('./TEST/cpu_epoch35_trainloss_0.037_validacc_0.961.pth'))
model.eval()
with torch.no_grad():
    for idx, x, y in test_loader:
        # Only idx[0]'s label is read from disk below, so each batch must hold one case.
        assert len(idx) == 1, "prediction cells expect test_loader with batch_size=1"
        print(idx)
        print('[1] shape of fat and fat_label: ', x.shape, y.shape)

        # Step 1, predict. No autocast: `evaluate`, which produced the validation accuracy
        # in the checkpoint name, runs in full precision, and fp16 here can disagree.
        logits = model(x.to(next(model.parameters()).device))
        logits = F.softmax(logits, dim=1)
        print('[2] shape of logits: ', logits.shape)

        # Step 2, per-voxel class index; drop the batch axis.
        fat_argmax = logits.argmax(dim=1).cpu()
        print('[3] argmax of fat_prediction: ', fat_argmax.shape)
        fat_argmax = fat_argmax[0].numpy()
        print('[4] reshape of fat_argmax: ', fat_argmax.shape)

        # Step 3, ground truth: the `slice_interval` slab along axis 0, cropped by 2 voxels
        # on both sides of axis 1.
        input_y = np.load(data_path + '/' + idx[0] + '/water_label.npy')
        slice_y = input_y[slice_interval[0]:slice_interval[1], 2:-2, :]

        # Step 4, resize the prediction onto the label grid (order=0 keeps class indices
        # intact) rather than hard-coding the target size, then one-hot encode both.
        zoom_factors = [t / s for t, s in zip(slice_y.shape, fat_argmax.shape)]
        fat_argmax = zoom(fat_argmax, zoom_factors, order=0, mode='nearest')
        prediction_label_origin = one_hot_encoding(fat_argmax)
        print('[5] one-hot encoding of prediction: ', prediction_label_origin.shape)
        img_fat_label = one_hot_encoding(slice_y)
        print('[6] one-hot encoding of origin: ', img_fat_label.shape)

        # Step 5, foreground (channel 1) Dice and voxel-count breakdown.
        print('[7] dice of muscle: ', dice_coef(img_fat_label[:,:,:,1], prediction_label_origin[:,:,:,1]))
        print('[8] numerical analysis:')
        numerical_analysis(img_fat_label, prediction_label_origin)


('BJ00000075',)
[1] shape of fat and fat_label:  torch.Size([1, 1, 304, 256, 320]) torch.Size([1, 2, 304, 256, 320])
[2] shape of logits:  torch.Size([1, 2, 304, 256, 320])
[3] argmax of fat_prediction:  torch.Size([1, 304, 256, 320])
[4] reshape of fat_argmax:  torch.Size([304, 256, 320])
[5] one-hot encoding of prediction:  (304, 256, 320, 2)
[6] one-hot encoding of origin:  (304, 256, 320, 2)
[7] dice of muscle:  0.9701602630808448
[8] numerical analysis:
prediction:  2276439.0  true:  2235329.0 intersection:  2188569.0 dice:  0.9701602564670878
[31m--- false negative: [0m 46760.0 [31m--- false positive: [0m 87870.0 

('BJ00000060',)
[1] shape of fat and fat_label:  torch.Size([1, 1, 304, 256, 320]) torch.Size([1, 2, 304, 256, 320])
[2] shape of logits:  torch.Size([1, 2, 304, 256, 320])
[3] argmax of fat_prediction:  torch.Size([1, 304, 256, 320])
[4] reshape of fat_argmax:  torch.Size([304, 256, 320])
[5] one-hot encoding of prediction:  (304, 256, 320, 2)
[6] one-hot encoding

In [None]:
# Evaluate checkpoint cpu_epoch34_trainloss_0.037_validacc_0.959.pth on every case in test_loader.
# Loading it here keeps the cell independent of whatever weights `model` currently holds.
model.load_state_dict(torch.load('./TEST/cpu_epoch34_trainloss_0.037_validacc_0.959.pth'))
model.eval()
with torch.no_grad():
    for idx, x, y in test_loader:
        # Only idx[0]'s label is read from disk below, so each batch must hold one case.
        assert len(idx) == 1, "prediction cells expect test_loader with batch_size=1"
        print(idx)
        print('[1] shape of fat and fat_label: ', x.shape, y.shape)

        # Step 1, predict. No autocast: `evaluate`, which produced the validation accuracy
        # in the checkpoint name, runs in full precision, and fp16 here can disagree.
        logits = model(x.to(next(model.parameters()).device))
        logits = F.softmax(logits, dim=1)
        print('[2] shape of logits: ', logits.shape)

        # Step 2, per-voxel class index; drop the batch axis.
        fat_argmax = logits.argmax(dim=1).cpu()
        print('[3] argmax of fat_prediction: ', fat_argmax.shape)
        fat_argmax = fat_argmax[0].numpy()
        print('[4] reshape of fat_argmax: ', fat_argmax.shape)

        # Step 3, ground truth: the `slice_interval` slab along axis 0, cropped by 2 voxels
        # on both sides of axis 1.
        input_y = np.load(data_path + '/' + idx[0] + '/water_label.npy')
        slice_y = input_y[slice_interval[0]:slice_interval[1], 2:-2, :]

        # Step 4, resize the prediction onto the label grid (order=0 keeps class indices
        # intact) rather than hard-coding the target size, then one-hot encode both.
        zoom_factors = [t / s for t, s in zip(slice_y.shape, fat_argmax.shape)]
        fat_argmax = zoom(fat_argmax, zoom_factors, order=0, mode='nearest')
        prediction_label_origin = one_hot_encoding(fat_argmax)
        print('[5] one-hot encoding of prediction: ', prediction_label_origin.shape)
        img_fat_label = one_hot_encoding(slice_y)
        print('[6] one-hot encoding of origin: ', img_fat_label.shape)

        # Step 5, foreground (channel 1) Dice and voxel-count breakdown.
        print('[7] dice of muscle: ', dice_coef(img_fat_label[:,:,:,1], prediction_label_origin[:,:,:,1]))
        print('[8] numerical analysis:')
        numerical_analysis(img_fat_label, prediction_label_origin)


('BJ00000075',)
[1] shape of fat and fat_label:  torch.Size([1, 1, 304, 256, 320]) torch.Size([1, 2, 304, 256, 320])
[2] shape of logits:  torch.Size([1, 2, 304, 256, 320])
[3] argmax of fat_prediction:  torch.Size([1, 304, 256, 320])
[4] reshape of fat_argmax:  torch.Size([304, 256, 320])
[5] one-hot encoding of prediction:  (304, 256, 320, 2)
[6] one-hot encoding of origin:  (304, 256, 320, 2)
[7] dice of muscle:  0.9677196510816715
[8] numerical analysis:
prediction:  2317254.0  true:  2235329.0 intersection:  2202812.0 dice:  0.9677196439911145
[31m--- false negative: [0m 32517.0 [31m--- false positive: [0m 114442.0 

('BJ00000060',)
[1] shape of fat and fat_label:  torch.Size([1, 1, 304, 256, 320]) torch.Size([1, 2, 304, 256, 320])
[2] shape of logits:  torch.Size([1, 2, 304, 256, 320])
[3] argmax of fat_prediction:  torch.Size([1, 304, 256, 320])
[4] reshape of fat_argmax:  torch.Size([304, 256, 320])
[5] one-hot encoding of prediction:  (304, 256, 320, 2)
[6] one-hot encodin

# ------ gpu

In [None]:
# Evaluate checkpoint epoch35_trainloss_0.037_validacc_0.961.pth (the GPU-saved original of the
# cpu_epoch35 file) on every case in test_loader.
# Loading it here keeps the cell independent of whatever weights `model` currently holds.
# NOTE(review): file name inferred from the cpu_* naming scheme — confirm it exists in ./TEST.
model.load_state_dict(torch.load('./TEST/epoch35_trainloss_0.037_validacc_0.961.pth'))
model.eval()
with torch.no_grad():
    for idx, x, y in test_loader:
        # Only idx[0]'s label is read from disk below, so each batch must hold one case.
        assert len(idx) == 1, "prediction cells expect test_loader with batch_size=1"
        print(idx)
        print('[1] shape of fat and fat_label: ', x.shape, y.shape)

        # Step 1, predict. No autocast: `evaluate`, which produced the validation accuracy
        # in the checkpoint name, runs in full precision. The earlier autocast run of this
        # cell scored far below the CPU run (0.707 vs 0.970 Dice on BJ00000075) — possibly
        # due to fp16; confirm.
        logits = model(x.to(next(model.parameters()).device))
        logits = F.softmax(logits, dim=1)
        print('[2] shape of logits: ', logits.shape)

        # Step 2, per-voxel class index; drop the batch axis.
        fat_argmax = logits.argmax(dim=1).cpu()
        print('[3] argmax of fat_prediction: ', fat_argmax.shape)
        fat_argmax = fat_argmax[0].numpy()
        print('[4] reshape of fat_argmax: ', fat_argmax.shape)

        # Step 3, ground truth: the `slice_interval` slab along axis 0, cropped by 2 voxels
        # on both sides of axis 1.
        input_y = np.load(data_path + '/' + idx[0] + '/water_label.npy')
        slice_y = input_y[slice_interval[0]:slice_interval[1], 2:-2, :]

        # Step 4, resize the prediction onto the label grid (order=0 keeps class indices
        # intact) rather than hard-coding the target size, then one-hot encode both.
        zoom_factors = [t / s for t, s in zip(slice_y.shape, fat_argmax.shape)]
        fat_argmax = zoom(fat_argmax, zoom_factors, order=0, mode='nearest')
        prediction_label_origin = one_hot_encoding(fat_argmax)
        print('[5] one-hot encoding of prediction: ', prediction_label_origin.shape)
        img_fat_label = one_hot_encoding(slice_y)
        print('[6] one-hot encoding of origin: ', img_fat_label.shape)

        # Step 5, foreground (channel 1) Dice and voxel-count breakdown.
        print('[7] dice of muscle: ', dice_coef(img_fat_label[:,:,:,1], prediction_label_origin[:,:,:,1]))
        print('[8] numerical analysis:')
        numerical_analysis(img_fat_label, prediction_label_origin)


('BJ00000075',)
[1] shape of fat and fat_label:  torch.Size([1, 1, 304, 256, 320]) torch.Size([1, 2, 304, 256, 320])
[2] shape of logits:  torch.Size([1, 2, 304, 256, 320])
[3] argmax of fat_prediction:  torch.Size([1, 304, 256, 320])
[4] reshape of fat_argmax:  torch.Size([304, 256, 320])
[5] one-hot encoding of prediction:  (304, 256, 320, 2)
[6] one-hot encoding of origin:  (304, 256, 320, 2)
[7] dice of muscle:  0.7073269423695667
[8] numerical analysis:
prediction:  1585377.0  true:  2235329.0 intersection:  1351244.0 dice:  0.7073268657677403
[31m--- false negative: [0m 884085.0 [31m--- false positive: [0m 234133.0 

('BJ00000060',)
[1] shape of fat and fat_label:  torch.Size([1, 1, 304, 256, 320]) torch.Size([1, 2, 304, 256, 320])
[2] shape of logits:  torch.Size([1, 2, 304, 256, 320])
[3] argmax of fat_prediction:  torch.Size([1, 304, 256, 320])
[4] reshape of fat_argmax:  torch.Size([304, 256, 320])
[5] one-hot encoding of prediction:  (304, 256, 320, 2)
[6] one-hot encodi

# --- run single case

In [None]:
# Evaluate one case with the currently loaded `model` and keep its arrays
# (pid, prediction_label_origin, img_fat_label) for later inspection.
TARGET_ID = 'BJ00000075'
pid = None
prediction_label_origin = None
img_fat_label = None

model.eval()
with torch.no_grad():
    for idx, x, y in test_loader:
        if idx[0] != TARGET_ID:
            continue
        print(idx)
        pid = idx[0]
        print('[1] shape of fat and fat_label: ', x.shape, y.shape)

        # Step 1, predict in full precision, matching `evaluate`.
        logits = model(x.to(next(model.parameters()).device))
        logits = F.softmax(logits, dim=1)
        print('[2] shape of logits: ', logits.shape)

        # Step 2, per-voxel class index; drop the batch axis.
        fat_argmax = logits.argmax(dim=1).cpu()
        print('[3] argmax of fat_prediction: ', fat_argmax.shape)
        fat_argmax = fat_argmax[0].numpy()
        print('[4] reshape of fat_argmax: ', fat_argmax.shape)

        # Step 3, ground truth with the same slab and 2-voxel axis-1 crop as the all-case
        # cells. The old hard-coded (288, 260, 320) target and uncropped label no longer match
        # the current `dim` and crashed on a fresh run.
        input_y = np.load(data_path + '/' + idx[0] + '/water_label.npy')
        slice_y = input_y[slice_interval[0]:slice_interval[1], 2:-2, :]

        # Step 4, resize the prediction onto the label grid (order=0 keeps class indices
        # intact), then one-hot encode both volumes.
        zoom_factors = [t / s for t, s in zip(slice_y.shape, fat_argmax.shape)]
        fat_argmax = zoom(fat_argmax, zoom_factors, order=0, mode='nearest')
        prediction_label_origin = one_hot_encoding(fat_argmax)
        print('[5] one-hot encoding of prediction: ', prediction_label_origin.shape)
        img_fat_label = one_hot_encoding(slice_y)
        print('[6] one-hot encoding of origin: ', img_fat_label.shape)

        # Step 5, foreground (channel 1) Dice and voxel-count breakdown.
        print('[7] dice of muscle: ', dice_coef(img_fat_label[:,:,:,1], prediction_label_origin[:,:,:,1]))
        print('[8] numerical analysis:')
        numerical_analysis(img_fat_label, prediction_label_origin)
        break  # target processed; skip loading the remaining cases


('BJ00000075',)
[1] shape of fat and fat_label:  torch.Size([1, 1, 288, 208, 256]) torch.Size([1, 2, 288, 208, 256])
[2] shape of logits:  torch.Size([1, 2, 288, 208, 256])
[3] argmax of fat_prediction:  torch.Size([1, 288, 208, 256])
[4] reshape of fat_argmax:  torch.Size([288, 208, 256])
[5] one-hot encoding of prediction:  (288, 260, 320, 2)
[6] one-hot encoding of origin:  (288, 260, 320, 2)
[7] dice of muscle:  0.9105527919746781
[8] numerical analysis:
prediction:  2440769.0  true:  2235329.0 intersection:  2128917.0 dice:  0.910552772846078
[31m--- false negative: [0m 106412.0 [31m--- false positive: [0m 311852.0 

