In [8]:
## Data augmentation: resize-and-pad transform applied before inference

import cv2
import numpy as np
from albumentations import Compose as Compose_albu
from albumentations import LongestMaxSize, PadIfNeeded


def to_numpy(data):
    """Convert the ``'image'`` entry (and ``'label'`` unless it is None) to NumPy arrays.

    The dict is updated in place and also returned so calls can be chained.
    """
    raw_image = data['image']
    raw_label = data['label']
    data['image'] = np.array(raw_image)
    if raw_label is not None:
        data['label'] = np.array(raw_label)
    return data


class MedicalTransform:
    """Scale-then-pad transform that brings a sample towards ``output_size``.

    The longest image side is scaled to ``max(output_size)`` with
    ``LongestMaxSize`` and the result is padded with ``PadIfNeeded`` so that
    each side is at least the matching ``output_size`` entry. The same
    geometric transform is applied to ``'label'`` when it is not ``None``.

    Args:
        output_size: an int for a square target, or a ``(height, width)``
            tuple/list.
    """

    def __init__(self, output_size):
        if isinstance(output_size, (tuple, list)):
            self._output_size = output_size
        else:
            self._output_size = (output_size, output_size)

    @staticmethod
    def _is_volume(img):
        """Return True when ``img`` is 4-D and must be augmented slice-by-slice along axis 2."""
        # The previous check, ``img.shape == 4``, compared a shape tuple with an
        # int and was therefore always False, leaving the volume branch dead.
        return img.ndim == 4

    def __call__(self, data):
        """Augment ``data['image']`` / ``data['label']`` in place and return ``data``."""
        data = to_numpy(data)
        img, label = data['image'], data['label']
        max_size = max(self._output_size[0], self._output_size[1])
        # A fresh Compose per call: the volume branch registers extra targets on
        # it with ``add_targets``, which must not accumulate across calls.
        aug = Compose_albu([
            LongestMaxSize(max_size, p=1),
            PadIfNeeded(self._output_size[0], self._output_size[1])
        ])

        if not self._is_volume(img):
            aug_data = aug(image=img, mask=label)
            data['image'], data['label'] = aug_data['image'], aug_data['mask']
            return data

        # Volume: register every slice along axis 2 as an extra target of a
        # single Compose call so all slices go through the same transform.
        # Unlabelled samples (label is None) only augment the image slices;
        # indexing a None label here used to raise a TypeError.
        num_slices = img.shape[2]
        has_label = label is not None
        extra_targets = {}
        inputs = {'image': img[:, :, 0]}
        if has_label:
            inputs['mask'] = label[:, :, 0]
        for i in range(1, num_slices):
            extra_targets[f'image{i}'] = 'image'
            inputs[f'image{i}'] = img[:, :, i]
            if has_label:
                extra_targets[f'mask{i}'] = 'mask'
                inputs[f'mask{i}'] = label[:, :, i]
        aug.add_targets(extra_targets)

        aug_data = aug(**inputs)
        data['image'] = np.stack(
            [aug_data['image']] + [aug_data[f'image{i}'] for i in range(1, num_slices)],
            axis=-1)
        if has_label:
            data['label'] = np.stack(
                [aug_data['mask']] + [aug_data[f'mask{i}'] for i in range(1, num_slices)],
                axis=-1)
        return data


In [9]:
## Network definition
import torch.nn as nn
from torchvision import models


class DenseUNet(nn.Module):
    """U-Net-style segmentation network with a torchvision DenseNet-161 encoder.

    Encoder stages of ``densenet161().features`` are reused directly. The
    decoder upsamples with ``_Up`` blocks, each fusing the coarser map with the
    matching encoder skip connection.

    Args:
        in_ch: number of input channels. When it differs from 3, the stem
            convolution is replaced by a freshly (randomly) initialised one.
        out_ch: number of output classes.
        pretrained: forwarded to ``models.densenet161``. Pass ``False`` when
            the weights will be overwritten from a checkpoint anyway, to skip
            downloading the ImageNet weights.

    ``forward`` returns a dict: ``'output'`` holds the class logits after the
    final upsampling stage, and ``'up1_cls'`` .. ``'up4_cls'`` hold 1x1-conv
    class maps of the four intermediate decoder stages (coarsest first).
    """

    def __init__(self, in_ch=3, out_ch=3, pretrained=True):
        super(DenseUNet, self).__init__()
        densenet = models.densenet161(pretrained=pretrained)
        # Children of ``densenet.features``, in order: conv0, norm0, relu0,
        # pool0, denseblock1, transition1, denseblock2, transition2,
        # denseblock3, transition3, denseblock4, norm5.
        backbone = list(list(densenet.children())[0].children())

        if in_ch != 3:
            # Same geometry as DenseNet-161's conv0 (96 filters, 7x7, stride 2)
            # so every later layer still lines up.
            backbone[0] = nn.Conv2d(in_ch, 96, kernel_size=7, stride=2, padding=3, bias=False)

        self.conv1 = nn.Sequential(*backbone[:3])  # conv0 + norm0 + relu0
        self.mp = backbone[3]
        self.denseblock1 = backbone[4]
        self.transition1 = backbone[5]
        self.denseblock2 = backbone[6]
        self.transition2 = backbone[7]
        self.denseblock3 = backbone[8]
        self.transition3 = backbone[9]
        self.denseblock4 = backbone[10]
        self.bn = backbone[11]
        # Decoder channel counts follow the DenseNet-161 feature widths:
        # denseblock4 2208, denseblock3 2112, denseblock2 768, denseblock1 384, stem 96.
        self.up1 = _Up(x1_ch=2208, x2_ch=2112, out_ch=768)
        self.up2 = _Up(x1_ch=768, x2_ch=768, out_ch=384)
        self.up3 = _Up(x1_ch=384, x2_ch=384, out_ch=96)
        self.up4 = _Up(x1_ch=96, x2_ch=96, out_ch=96)
        self.up5 = nn.Sequential(
            _Interpolate(),
            nn.BatchNorm2d(num_features=96),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, padding=1)
        )
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=out_ch, kernel_size=1)

        # 1x1 heads producing auxiliary class maps from each decoder stage.
        self.up1_conv = nn.Conv2d(in_channels=768, out_channels=out_ch, kernel_size=1)
        self.up2_conv = nn.Conv2d(in_channels=384, out_channels=out_ch, kernel_size=1)
        self.up3_conv = nn.Conv2d(in_channels=96, out_channels=out_ch, kernel_size=1)
        self.up4_conv = nn.Conv2d(in_channels=96, out_channels=out_ch, kernel_size=1)

    def forward(self, x):
        # Encoder; x, x1, x2 and x3 are kept as skip connections.
        x = self.conv1(x)
        x_ = self.mp(x)
        x1 = self.denseblock1(x_)
        x1t = self.transition1(x1)
        x2 = self.denseblock2(x1t)
        x2t = self.transition2(x2)
        x3 = self.denseblock3(x2t)
        x3t = self.transition3(x3)
        x4 = self.denseblock4(x3t)
        x4 = self.bn(x4)
        # Decoder: each _Up upsamples 2x and fuses the matching skip.
        x5 = self.up1(x4, x3)
        x6 = self.up2(x5, x2)
        x7 = self.up3(x6, x1)
        x8 = self.up4(x7, x)
        feat = self.up5(x8)
        cls = self.conv2(feat)

        up1_cls = self.up1_conv(x5)
        up2_cls = self.up2_conv(x6)
        up3_cls = self.up3_conv(x7)
        up4_cls = self.up4_conv(x8)

        return {'output': cls, 'up1_cls': up1_cls, 'up2_cls': up2_cls, 'up3_cls': up3_cls, 'up4_cls': up4_cls}


class _Interpolate(nn.Module):
    def __init__(self, scale_factor=2, mode='bilinear', align_corners=True):
        super(_Interpolate, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners
    
    def forward(self, x):
        x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode=self.mode,
                                      align_corners=self.align_corners)
        return x


class _Up(nn.Module):
    """Decoder stage: upsample ``x1`` 2x, add a projected skip ``x2``, then conv-BN-ReLU.

    Args:
        x1_ch: channels of the coarser input ``x1``.
        x2_ch: channels of the skip input ``x2``; a 1x1 conv projects it to
            ``x1_ch`` channels so the two maps can be summed.
        out_ch: channels produced by the 3x3 convolution.
    """

    def __init__(self, x1_ch, x2_ch, out_ch):
        super(_Up, self).__init__()
        self.up = _Interpolate()
        self.conv1x1 = nn.Conv2d(x2_ch, x1_ch, 1)
        self.conv = nn.Sequential(
            nn.Conv2d(x1_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x1, x2):
        fused = self.up(x1) + self.conv1x1(x2)
        return self.conv(fused)


In [10]:
## Checkpoint saving and loading
import torch
from pathlib2 import Path

def save(epoch, net, optimizer, root):
    """Serialise a training checkpoint to ``root`` with ``torch.save``.

    The checkpoint is a dict with keys ``'epoch'``, ``'net'`` (the model's
    ``state_dict``) and ``'optimizer'`` (the optimizer's ``state_dict``).
    """
    checkpoint = dict(
        epoch=epoch,
        net=net.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(checkpoint, root)

def _key_exist(data, cp, key):
    return key in data and data[key] and key in cp and cp[key]

def load_params(data, cp_file, device='cpu'):
    """Restore the entries of ``data`` in place from the checkpoint at ``cp_file``.

    ``data`` may hold ``'net'`` (a module), ``'optimizer'`` and ``'epoch'``.
    Each entry that ``_key_exist`` reports as present in both ``data`` and the
    checkpoint is restored.

    Args:
        data: dict of objects to restore; updated in place and returned.
        cp_file: path of the checkpoint file.
        device: ``map_location`` for ``torch.load``; tensor-valued optimizer
            state is also moved here after loading.

    Returns:
        The same ``data`` dict.

    Raises:
        AssertionError: (via ``assert``) if ``cp_file`` does not exist.
    """
    checkpoint_path = Path(cp_file)
    assert checkpoint_path.exists()

    checkpoint = torch.load(str(checkpoint_path), map_location=device)

    if _key_exist(data, checkpoint, key='net'):
        data['net'].load_state_dict(checkpoint['net'])

    if _key_exist(data, checkpoint, key='optimizer'):
        optimizer = data['optimizer']
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Explicitly place every tensor in the per-parameter state on `device`.
        for per_param_state in optimizer.state.values():
            for name, value in per_param_state.items():
                if isinstance(value, torch.Tensor):
                    per_param_state[name] = value.to(device)

    if _key_exist(data, checkpoint, key='epoch'):
        data['epoch'] = checkpoint['epoch']

    return data
    

In [11]:
# Visualization helpers
import matplotlib.pyplot as plt
import torch

def numpy_to_plt(img):
    """Reorder a channel-first ``(C, H, W)`` array into matplotlib's ``(H, W, C)`` layout."""
    channel_axis, height_axis, width_axis = 0, 1, 2
    return img.transpose((height_axis, width_axis, channel_axis))

def imshow(title, imgs, shape=None, subtitle=None, cmap=None, transpose=False, pause=0.001, pltshow=True):
    """Show one image, or a grid of images, in a matplotlib figure named ``title``.

    Args:
        title: figure name and super-title.
        imgs: a single image, or a *tuple* of images laid out on a grid.
        shape: ``(rows, cols)`` grid for a tuple of images; defaults to a
            single row. ``rows * cols`` must equal the number of images.
        subtitle: per-panel titles as a tuple; any non-tuple value is repeated
            for every panel.
        cmap: per-panel colormaps as a tuple; any non-tuple value is repeated.
            For a panel whose colormap is ``None``: 2-D and ``(1, H, W)``
            images are drawn in gray, and ``(3, H, W)`` images are transposed to
            ``(H, W, 3)``.
        transpose: single-image mode only; convert a ``(C, H, W)`` image to
            ``(H, W, C)`` before drawing.
        pause: seconds passed to ``plt.pause`` after showing.
        pltshow: when True, turn on interactive mode, show and pause.

    Returns:
        The current matplotlib figure.
    """
    if type(imgs) is tuple:
        num = len(imgs)
        if shape is not None:
            assert shape[0] * shape[1] == num
        else:
            shape = (1, num)

        if type(subtitle) is not tuple:
            subtitle = (subtitle,) * num
        else:
            assert len(subtitle) == num

        if type(cmap) is not tuple:
            cmap = (cmap,) * num
        else:
            assert len(cmap) == num

        fig = plt.figure(num=title, figsize=(shape[1] * 3, shape[0] * 3 + 0.5))
        fig.clf()
        fig.suptitle(title)

        fig.subplots(shape[0], shape[1], sharex=True, sharey=True)
        axes = fig.get_axes()

        # Axes are returned row-major, matching the order of `imgs`.
        for idx in range(num):
            axes[idx].set_title(subtitle[idx])

            cm = cmap[idx]
            img = imgs[idx]
            if cm is None:
                ndim = len(img.shape)
                # The 2-D case used to sit under an outer `ndim == 3` guard and
                # was unreachable, so 2-D images fell back to the default colormap.
                if ndim == 2:
                    cm = 'gray'
                elif ndim == 3 and img.shape[0] == 1:
                    cm = 'gray'
                    img = img.reshape((img.shape[1], img.shape[2]))
                elif ndim == 3 and img.shape[0] == 3:
                    img = numpy_to_plt(img)
            axes[idx].imshow(img, cmap=cm)

    else:
        if transpose:
            imgs = numpy_to_plt(imgs)
        plt.figure(num=title)
        plt.suptitle(title)
        plt.title(subtitle)
        plt.imshow(imgs, cmap=cmap)

    if pltshow:
        plt.ion()
        plt.show()
        plt.pause(pause)

    return plt.gcf()



In [12]:
# Inference dataset (data loading and preprocessing)
import torch.nn as nn
from torch.utils import data
from pathlib2 import Path
from albumentations import Compose, PadIfNeeded, Resize
class TestSet(data.Dataset):
    """Inference-only dataset over the ``.jpg``/``.png``/``.npy`` files in one directory.

    Each ``.npy`` array is stacked ``stack_num`` times along a new axis 2 (a
    2-D slice therefore becomes ``(H, W, stack_num)``); ``.jpg``/``.png`` files
    are read as RGB. Samples then go through ``transform`` (if any), are
    resized to ``img_size`` when needed, and are returned as channel-first
    float32 tensors. No labels are available, so ``'label'`` is always an
    empty tensor.

    Args:
        root: directory holding the files (not searched recursively).
        stack_num: number of copies of each ``.npy`` array stacked as channels.
        img_size: int or ``(height, width)`` of the returned images.
        transform: optional callable applied to the sample dict before the
            default resize/tensor conversion.

    Raises:
        ValueError: if ``root`` contains no matching files.
    """

    def __init__(self, root, stack_num=1, img_size=(512, 512), transform=None):
        self._root = Path(root)
        self._stack_num = stack_num

        # Normalise to a tuple so the size comparison in `_default_transform`
        # works for lists as well (a list never compares equal to a tuple).
        if isinstance(img_size, (tuple, list)):
            self._img_size = tuple(img_size)
        else:
            self._img_size = (img_size, img_size)
        self._transform = transform

        self._get_data()
        if not self._imgs:
            raise ValueError(f'no .jpg/.png/.npy files found in {self._root}')

        # Probe the first sample to learn the channel count fed to the network.
        self._img_channels = self.__getitem__(0)['image'].shape[0]

    def _get_data(self, extentions=('*.jpg', '*.png', '*.npy')):
        """Collect matching files under ``root``; paths within each pattern are sorted."""
        self._imgs = []
        for extention in extentions:
            # Sorted so the sample order is deterministic (glob order is not).
            self._imgs += sorted(self._root.glob(extention))

    def get_stack_num(self, idx):
        """Load file ``idx`` as a sample dict ``{'image': array, 'label': None}``."""
        data_path = self._imgs[idx]

        if data_path.suffix == '.npy':
            img = self.__npydata(data_path)
            imgs = np.stack([img] * self._stack_num, axis=2)
        else:
            # .jpg/.png: previously returned None, which crashed the transform.
            # Loaded as 3-channel RGB and not replicated, so these match .npy
            # samples only when stack_num == 3.
            imgs = self.__imgdata(data_path)
        return {'image': imgs, 'label': None}

    def __getitem__(self, idx):
        data = self.get_stack_num(idx)
        if self._transform is not None:
            data = self._transform(data)
        return self._default_transform(data)

    def __len__(self):
        return len(self._imgs)

    def __npydata(self, data_path):
        return np.load(str(data_path))

    def __imgdata(self, data_path):
        return cv2.cvtColor(cv2.imread(str(data_path)), cv2.COLOR_BGR2RGB)

    def vis_transform(self, data):
        """Convert a batch dict into NumPy arrays ready for display.

        ``'image'`` is converted to NumPy. ``'label'`` and ``'predict'`` index
        maps of shape ``(N, H, W)`` are colourised with the palette
        black/green/blue into ``(N, 3, H, W)`` floats in ``[0, 1]``;
        ``(N, num_classes, H, W)`` prediction logits are arg-maxed first.
        """
        # Builtin `int`: the `np.int` alias was removed in NumPy 1.24.
        cmap = np.array([[0, 0, 0], [0, 255, 0], [0, 0, 255]], dtype=int)
        if 'image' in data.keys() and data['image'] is not None:
            imgs = data['image']
            if type(imgs).__module__ != np.__name__:
                imgs = imgs.cpu().detach().numpy()
            data['image'] = imgs

        if 'label' in data.keys() and data['label'] is not None and data['label'].shape[-1] != 0:
            labels = data['label']
            if type(labels).__module__ != np.__name__:
                labels = labels.cpu().detach().numpy()
            labels = cmap[labels]
            labels = labels.transpose((0, 3, 1, 2))
            labels = labels / 255
            data['label'] = labels

        if 'predict' in data.keys() and data['predict'] is not None:
            preds = data['predict']
            if type(preds).__module__ != np.__name__:
                preds = preds.cpu().detach().numpy()
            # Only 4-D logits are arg-maxed; checking shape[1] alone misfired on
            # (N, H, W) label maps whose height equals num_classes.
            if preds.ndim == 4 and preds.shape[1] == self.num_classes:
                preds = preds.argmax(axis=1)
            preds = cmap[preds]
            preds = preds.transpose((0, 3, 1, 2))
            preds = preds / 255
            data['predict'] = preds

        return data

    def _resize(self, data):
        """Zero-pad the sample to a square, then resize it to ``img_size``; updates ``data`` in place."""
        data = to_numpy(data)
        img, label = data['image'], data['label']
        num = max(img.shape[0], img.shape[1])

        aug = Compose_albu([
            PadIfNeeded(min_height=num, min_width=num, border_mode=cv2.BORDER_CONSTANT, p=1),
            Resize(height=self._img_size[0], width=self._img_size[1], p=1)
        ])

        # albumentations expects the `image` keyword (the call used `img=` and
        # failed), and the result must be written back into the sample dict
        # instead of replacing it. A missing label is simply not passed.
        targets = {'image': img}
        if label is not None:
            targets['mask'] = label
        aug_data = aug(**targets)

        data['image'] = aug_data['image']
        data['label'] = aug_data.get('mask', label)
        return data

    def _default_transform(self, data):
        """Resize to ``img_size`` if needed, then HWC array -> CHW float32 tensor; label -> empty tensor."""
        if tuple(data['image'].shape[:2]) != self._img_size:
            data = self._resize(data)

        image = data['image'].astype(np.float32).transpose((2, 0, 1))
        data['image'] = torch.from_numpy(image)
        data['label'] = torch.Tensor()

        return data

    @property
    def img_channels(self):
        return self._img_channels

    @property
    def spec_classes(self):
        return 3

    @property
    def num_classes(self):
        return 3


In [13]:
## Inference entry point; supports both single-image and multi-image uploads
import cv2
import numpy as np
import torch
import torch.nn as nn
from pathlib2 import Path
from torch.utils.data import DataLoader, SequentialSampler
import time


def main(batch_size, img_size, data_path, resume, output_path, show=False):
    """Run DenseUNet inference over every input file in ``data_path``.

    Args:
        batch_size: samples per forward pass.
        img_size: ``(height, width)`` the inputs are brought to.
        data_path: directory of ``.npy``/``.jpg``/``.png`` inputs.
        resume: checkpoint file whose ``'net'`` entry holds the trained weights.
        output_path: output directory, created if missing.
            TODO: predictions are not written to it yet.
        show: when True, display each first image channel next to its
            colour-coded prediction.
    """
    data_path = Path(data_path)
    output_path = Path(output_path)
    assert data_path.exists()
    if not output_path.exists():
        output_path.mkdir(parents=True)

    transform = MedicalTransform(output_size=img_size)
    dataset = TestSet(root=data_path, stack_num=3, img_size=img_size, transform=transform)
    print(dataset.img_channels)
    # in_ch is the image channel count and out_ch the number of classes; the
    # two used to be swapped, which only worked because both equal 3 here.
    net = DenseUNet(in_ch=dataset.img_channels, out_ch=dataset.num_classes)

    cp_file = Path(resume)
    assert cp_file.exists()
    load_params({'net': net}, cp_file)

    print(f'{" Start evaluation ":-^40s}\n')
    msg = f'Net: {net.__class__.__name__}\n' + \
          f'Batch size: {batch_size}\n'
    print(msg)

    net.eval()
    sampler = SequentialSampler(dataset)
    # Pinned memory only helps host-to-GPU copies.
    data_loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=1,
                             pin_memory=torch.cuda.is_available())

    # A scoped no_grad instead of torch.set_grad_enabled(False), which would
    # leave autograd disabled for every later cell in the kernel.
    with torch.no_grad():
        for batch in data_loader:
            imgs = batch['image'].cpu()
            start = time.perf_counter()
            outputs = net(imgs)
            print("total time:", time.perf_counter() - start)
            predicts = outputs['output'].argmax(dim=1).cpu().numpy()
            batch['predict'] = predicts
            print(batch['image'].shape, predicts.shape)

            if show:
                vis = dataset.vis_transform(batch)
                imshow(title='test', imgs=(vis['image'][0, 1], vis['predict'][0]),
                       shape=(1, 2), subtitle=('image', 'predict'))






In [14]:
main(
    batch_size=1,
    img_size=(320, 320),
    data_path='data/0/imaging',
    resume='checkpoint/best.pth',
    output_path='out',
)

3
----------- Start evaluation -----------

Net: DenseUNet
Batch size: 1

total time: 4.249403715133667
torch.Size([1, 3, 320, 320]) (1, 320, 320)
total time: 2.7039954662323
torch.Size([1, 3, 320, 320]) (1, 320, 320)
total time: 2.865765333175659
torch.Size([1, 3, 320, 320]) (1, 320, 320)
total time: 2.821321487426758
torch.Size([1, 3, 320, 320]) (1, 320, 320)


KeyboardInterrupt: 