In [0]:
!apt-get install -y -qq software-properties-common python-software-properties module-init-tools
!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
!apt-get update -qq 2>&1 > /dev/null
!apt-get -y install -qq google-drive-ocamlfuse fuse
from google.colab import auth
auth.authenticate_user()
from oauth2client.client import GoogleCredentials
creds = GoogleCredentials.get_application_default()
import getpass
!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
vcode = getpass.getpass()
!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}


E: Package 'python-software-properties' has no installation candidate
Selecting previously unselected package google-drive-ocamlfuse.
(Reading database ... 144568 files and directories currently installed.)
Preparing to unpack .../google-drive-ocamlfuse_0.7.19-0ubuntu1~ubuntu18.04.1_amd64.deb ...
Unpacking google-drive-ocamlfuse (0.7.19-0ubuntu1~ubuntu18.04.1) ...
Setting up google-drive-ocamlfuse (0.7.19-0ubuntu1~ubuntu18.04.1) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
Please, open the following URL in a web browser: https://accounts.google.com/o/oauth2/auth?client_id=32555940559.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive&response_type=code&access_type=offline&approval_prompt=force
··········
Please, open the following URL in a web browser: https://accounts.google.com/o/oauth2/auth?client_id=32555940559.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope

# before coding


## 安装谷歌网盘

In [1]:
!mkdir -p drive
!google-drive-ocamlfuse drive  -o nonempty

/bin/bash: google-drive-ocamlfuse: command not found


## 安装关联库

In [1]:
!pip install numpy==1.16.2
!pip install tensorflow==1.12.2
!pip install scikit-learn==0.20.3
!pip install scikit-image==0.14.2
!pip install imageio==2.5.0
!pip install medpy==0.4.0
!pip install Pillow==6.0.0
!pip install scipy==1.2.1
!pip install pandas==0.24.2
!pip install tqdm==4.32.1




In [0]:
from google.colab import drive
drive.mount('/content/drive')

## 更改文件路径

In [0]:
import sys
sys.path.append('/content/drive/My Drive/Project/Brain Segmentation/Network Dissection/Network_Dissection_Version5')

# Coding

## 关联包

In [0]:
import argparse
import json
import os
import h5py

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as Data
import torch.nn.functional as F


from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import BrainSegmentationDataset as Dataset
from logger import Logger
from loss import DiceLoss
from transform import transforms
from unet import UNet
from utils import log_images, dsc
from IPython import embed

## Data loaders

In [0]:
def data_loaders(args):
    dataset_train, dataset_valid = datasets(args)

    def worker_init(worker_id):
        np.random.seed(42 + worker_id)

    loader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
        worker_init_fn=worker_init,
    )
    loader_valid = DataLoader(
        dataset_valid,
        batch_size=args.batch_size,
        drop_last=False,
        num_workers=args.workers,
        worker_init_fn=worker_init,
    )

    return loader_train, loader_valid


## Datasets

In [0]:
def datasets(args):
    train = Dataset(
        images_dir=args.images,
        subset="train",
        image_size=args.image_size,
        # transform=transforms(scale=args.aug_scale, angle=args.aug_angle, flip_prob=0.5),
    )
    valid = Dataset(
        images_dir=args.images,
        subset="validation",
        image_size=args.image_size,
        # random_sampling=False,
    )
    return train, valid


## Dice coefficient

In [0]:
def dsc_per_volume(validation_pred, validation_true, patient_slice_index):
    dsc_list = []
    label_dsc = []
    num_slices = np.bincount([p[0] for p in patient_slice_index])
    index = 0
    for p in range(len(num_slices)):
        y_pred = np.array(validation_pred[index : index + num_slices[p]])
        y_true = np.array(validation_true[index : index + num_slices[p]])
        overall_dsc, label = dsc(y_pred,y_true) 
        dsc_list.append(overall_dsc)
        label_dsc.append(label)
        index += num_slices[p]
    return dsc_list,label_dsc


## log

In [0]:
def log_loss_summary(logger, loss, step, prefix=""):
    logger.scalar_summary(prefix + "loss", np.mean(loss), step)


def makedirs(args):
    os.makedirs(args.weights, exist_ok=True)
    os.makedirs(args.logs, exist_ok=True)


def snapshotargs(args):
    args_file = os.path.join(args.logs, "args.json")
    with open(args_file, "w") as fp:
        json.dump(vars(args), fp)

## 主函数

In [0]:
def main(args):
    makedirs(args)
    snapshotargs(args)
    device = torch.device("cpu" if not torch.cuda.is_available() else args.device)

    loader_train, loader_valid = data_loaders(args)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
    # unet.apply(weights_init)
    unet.to(device)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    optimizer = optim.Adam(unet.parameters(), lr=args.lr, weight_decay=1e-4)
    # optimizer = optim.Adam(unet.parameters(), lr=args.lr)

    logger = Logger(args.logs)
    loss_train = []
    loss_valid = []

    validation_pred = []
    validation_true = []
    step = 0

    for epoch in tqdm(range(args.epochs), total=args.epochs):
        for phase in ["train", "valid"]:
            if phase == "train":
                unet.train()
            else:
                unet.eval()


            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1
                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)
                    loss = dsc_loss(y_pred, y_true)
                    # print(loss)

                    if phase == "valid":
                    # if phase == "train":
                        loss_valid.append(loss.item())
                        y_pred_np = y_pred.detach().cpu().numpy()
                        validation_pred.extend(
                            [y_pred_np[s] for s in range(y_pred_np.shape[0])]
                        )
                        y_true_np = y_true.detach().cpu().numpy()
                        validation_true.extend(
                            [y_true_np[s] for s in range(y_true_np.shape[0])]
                        )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()


            if phase == "valid":
                dsc, label_dsc = dsc_per_volume(
                          validation_pred,
                          validation_true,
                          loader_valid.dataset.patient_slice_index,
                          # loader_train.dataset.patient_slice_index,
                          )
                mean_dsc = np.mean(dsc)
                print(mean_dsc)
                print(np.array(label_dsc).mean(axis=0))

                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    best_label_dsc = label_dsc
                    torch.save(unet.state_dict(), os.path.join(args.weights, "unet.pt"))
                    opt = epoch;
                loss_valid = []
                validation_pred = []
                validation_true = []
    # torch.save(unet.state_dict(), os.path.join(args.weights, "unet.pt"))
    print("Best validation mean DSC: {:4f}".format(best_validation_dsc))
    print(opt);

## 主函数参数

In [9]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Training U-Net model for segmentation of brain MRI"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="input batch size for training (default: 16)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=1000,
        help="number of epochs to train (default: 100)",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0005,
        help="initial learning rate (default: 0.001)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="device for training (default: cuda:0)",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="number of workers for data loading (default: 4)",
    )
    parser.add_argument(
        "--vis-images",
        type=int,
        default=200,
        help="number of visualization images to save in log file (default: 200)",
    )
    parser.add_argument(
        "--vis-freq",
        type=int,
        default=10,
        help="frequency of saving images to log file (default: 10)",
    )
    parser.add_argument(
        "--weights", type=str, default="./weights", help="folder to save weights"
    )
    parser.add_argument(
        "--logs", type=str, default="./logs", help="folder to save logs"
    )
    # parser.add_argument(
    #     "--images", type=str, default="./kaggle_3m", help="root folder with images"
    # )
    parser.add_argument(
        "--images", type=str, default="/content/drive/My Drive/Project/DATA/TEST_DATA/HGG2", help="root folder with images"
        # "--images", type=str, default="/content/drive/My Drive/Project/DATA/TEST_DATA/HGG2/BraTS19_2013_10_1", help="root folder with images"
    )
    parser.add_argument(
        "--image-size",
        type=int,
        default=256,
        help="target input image size (default: 240)",
    )
    parser.add_argument(
        "--aug-scale",
        type=int,
        default=0.05,
        help="scale factor range for augmentation (default: 0.05)",
    )
    parser.add_argument(
        "--aug-angle",
        type=int,
        default=15,
        help="rotation angle range in degrees for augmentation (default: 15)",
    )

    # args = parser.parse_args()
    args = parser.parse_args(args=[])
    main(args)

reading train images...
preprocessing train volumes...
one hotting train volumes...
cropping train volumes...
padding train volumes...
resizing train volumes...


  .format(dtypeobj_in, dtypeobj_out))


finish
finish
finish
finish
finish
finish
finish
finish
finish
normalizing train volumes...
converting mask to one of train volumes...
done creating train dataset
reading validation images...
preprocessing validation volumes...
one hotting validation volumes...
cropping validation volumes...
padding validation volumes...
resizing validation volumes...
finish
finish
normalizing validation volumes...
converting mask to one of validation volumes...
done creating validation dataset


  0%|          | 1/1000 [00:04<1:14:43,  4.49s/it]

0.2699138542154802
[9.48653664e-01 2.86271632e-05 2.67646276e-02 1.04208498e-01]


  0%|          | 2/1000 [00:08<1:13:10,  4.40s/it]

0.4019279918658001
[0.95755685 0.04746306 0.45705226 0.1456398 ]


  0%|          | 3/1000 [00:12<1:12:06,  4.34s/it]

0.4695758026496647
[0.97588509 0.0970184  0.66055497 0.14484475]


  0%|          | 4/1000 [00:17<1:11:17,  4.29s/it]

0.47030192659098846
[0.9735393  0.09590285 0.61550415 0.1962614 ]


  0%|          | 5/1000 [00:21<1:10:23,  4.25s/it]

0.42038560080540244
[0.95019887 0.12655126 0.48468199 0.12011029]


  1%|          | 6/1000 [00:25<1:10:08,  4.23s/it]

0.4750041051895448
[0.97110076 0.01880336 0.68388362 0.22622868]


  1%|          | 7/1000 [00:29<1:09:38,  4.21s/it]

0.4733953827104828
[0.971417   0.0113925  0.63582919 0.27494284]


  1%|          | 8/1000 [00:33<1:09:18,  4.19s/it]

0.46064700808822134
[0.9676096  0.0699984  0.63787319 0.16710684]


  1%|          | 9/1000 [00:37<1:09:06,  4.18s/it]

0.46740992290422534
[0.97101954 0.08266894 0.60532986 0.21062135]


  1%|          | 10/1000 [00:42<1:08:52,  4.17s/it]

0.44836106199965375
[0.97023159 0.05440849 0.65237379 0.11643038]


  1%|          | 11/1000 [00:46<1:08:45,  4.17s/it]

0.4517402551909747
[0.97187924 0.15035608 0.5279658  0.1567599 ]


  1%|          | 12/1000 [00:50<1:08:43,  4.17s/it]

0.44564491021585684
[9.65936057e-01 7.47224044e-05 6.63703875e-01 1.52864986e-01]


  1%|▏         | 13/1000 [00:54<1:08:40,  4.17s/it]

0.4571781187798962
[0.97125333 0.01369989 0.6602845  0.18347476]


  1%|▏         | 14/1000 [00:58<1:08:33,  4.17s/it]

0.4600554070888748
[0.96485707 0.09945195 0.59206174 0.18385087]


  2%|▏         | 15/1000 [01:02<1:08:31,  4.17s/it]

0.4567249974393761
[0.97364091 0.00313981 0.68696879 0.16315048]


  2%|▏         | 16/1000 [01:07<1:08:19,  4.17s/it]

0.4581475459585216
[0.97377771 0.00601781 0.59445841 0.25833625]


  2%|▏         | 17/1000 [01:11<1:08:10,  4.16s/it]

0.461399703008067
[0.97331284 0.05942983 0.59644408 0.21641206]


  2%|▏         | 18/1000 [01:15<1:08:06,  4.16s/it]

0.4506782404488828
[0.96962187 0.00265647 0.60567164 0.22476299]


  2%|▏         | 19/1000 [01:19<1:08:12,  4.17s/it]

0.47501233956853006
[0.97245094 0.06245241 0.63178694 0.23335906]


  2%|▏         | 20/1000 [01:23<1:08:03,  4.17s/it]

0.44273305827860004
[0.97000458 0.00280797 0.65213193 0.14598776]


  2%|▏         | 21/1000 [01:27<1:07:50,  4.16s/it]

0.465841618121051
[0.97310724 0.04243345 0.63123901 0.21658678]


  2%|▏         | 22/1000 [01:32<1:07:59,  4.17s/it]

0.4851096940751172
[0.9725746  0.03783975 0.67473558 0.25528885]


  2%|▏         | 23/1000 [01:36<1:07:54,  4.17s/it]

0.46056074723211293
[0.96856465 0.04523206 0.60199475 0.22645153]


  2%|▏         | 24/1000 [01:40<1:07:45,  4.17s/it]

0.45495604447177107
[0.9729135  0.01180388 0.66067488 0.17443192]


  2%|▎         | 25/1000 [01:44<1:07:34,  4.16s/it]

0.46753151400767545
[0.97238775 0.01448481 0.64910268 0.23415082]


  3%|▎         | 26/1000 [01:48<1:07:28,  4.16s/it]

0.4582061598708064
[0.97469272 0.04732597 0.56309695 0.247709  ]


  3%|▎         | 27/1000 [01:52<1:07:21,  4.15s/it]

0.4387647951992157
[9.70490000e-01 2.70646595e-05 6.65461360e-01 1.19080756e-01]


  3%|▎         | 28/1000 [01:56<1:07:15,  4.15s/it]

0.4635607852634438
[0.96942177 0.04648469 0.67344832 0.16488836]


  3%|▎         | 29/1000 [02:01<1:07:06,  4.15s/it]

0.46961414420335357
[0.97345313 0.02763722 0.62185875 0.25550748]


  3%|▎         | 30/1000 [02:05<1:07:04,  4.15s/it]

0.4742456549083352
[0.97584661 0.01474484 0.6385655  0.26782566]


  3%|▎         | 31/1000 [02:09<1:07:03,  4.15s/it]

0.4487834659936416
[9.70919150e-01 2.85610801e-05 6.31291428e-01 1.92894725e-01]


  3%|▎         | 32/1000 [02:13<1:07:04,  4.16s/it]

0.4612288855268962
[0.97032317 0.05747267 0.6367897  0.18033001]


  3%|▎         | 33/1000 [02:17<1:07:03,  4.16s/it]

0.4273753059196837
[0.96374221 0.00238742 0.60135893 0.14201266]


  3%|▎         | 34/1000 [02:21<1:06:56,  4.16s/it]

0.4597733790198526
[0.97268218 0.00765736 0.61368067 0.24507331]


  4%|▎         | 35/1000 [02:26<1:06:50,  4.16s/it]

0.43579935105132483
[9.67754477e-01 2.65706055e-05 6.84846622e-01 9.05697351e-02]


  4%|▎         | 36/1000 [02:30<1:06:45,  4.16s/it]

0.44776974054677654
[0.97219609 0.01036831 0.57554465 0.23296991]


  4%|▎         | 37/1000 [02:34<1:06:40,  4.15s/it]

0.4483485301796405
[0.96900952 0.00556473 0.66729335 0.15152652]


  4%|▍         | 38/1000 [02:38<1:06:38,  4.16s/it]

0.46004196575926926
[9.71653084e-01 2.13698414e-04 6.50436668e-01 2.17864413e-01]


  4%|▍         | 39/1000 [02:42<1:06:33,  4.16s/it]

0.45911239140141824
[9.74135547e-01 2.90639430e-04 6.01387912e-01 2.60635467e-01]


KeyboardInterrupt: ignored

In [0]:
embed()

Python 3.6.9 (default, Apr 18 2020, 01:56:04) 
Type "copyright", "credits" or "license" for more information.

IPython 5.5.0 -- An enhanced Interactive Python.
?         -> Introduction and overview of IPython's features.
%quickref -> Quick reference.
help      -> Python's own help system.
object?   -> Details about 'object', use 'object??' for extra details.

In [1]: exit



# Network Dissection

## 加载模型

In [0]:
device = torch.device("cpu" if not torch.cuda.is_available() else args.device)
unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
model = unet.to(device)
model.load_state_dict(torch.load('/content/weights/unet.pt'))

<All keys matched successfully>

## 读取图片

In [0]:
import nibabel as nib

path = '/content/drive/My Drive/Project/DATA/TEST_DATA/HGG2/BraTS19_2013_2_1/BraTS19_2013_2_1_flair.nii.gz'
image_data = nib.load(path).get_fdata()

## 读取MASK

In [0]:
import nibabel as nib

path = '/content/drive/My Drive/Project/DATA/TEST_DATA/HGG2/BraTS19_2013_2_1/BraTS19_2013_2_1_seg.nii.gz'
mask_data = nib.load(path).get_data()


* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  after removing the cwd from sys.path.


## 打印layer_name

In [0]:
for name in model.state_dict():
   print(name)

encoder1.enc1conv1.weight
encoder1.enc1norm1.weight
encoder1.enc1norm1.bias
encoder1.enc1norm1.running_mean
encoder1.enc1norm1.running_var
encoder1.enc1norm1.num_batches_tracked
encoder1.enc1conv2.weight
encoder1.enc1norm2.weight
encoder1.enc1norm2.bias
encoder1.enc1norm2.running_mean
encoder1.enc1norm2.running_var
encoder1.enc1norm2.num_batches_tracked
encoder2.enc2conv1.weight
encoder2.enc2norm1.weight
encoder2.enc2norm1.bias
encoder2.enc2norm1.running_mean
encoder2.enc2norm1.running_var
encoder2.enc2norm1.num_batches_tracked
encoder2.enc2conv2.weight
encoder2.enc2norm2.weight
encoder2.enc2norm2.bias
encoder2.enc2norm2.running_mean
encoder2.enc2norm2.running_var
encoder2.enc2norm2.num_batches_tracked
encoder3.enc3conv1.weight
encoder3.enc3norm1.weight
encoder3.enc3norm1.bias
encoder3.enc3norm1.running_mean
encoder3.enc3norm1.running_var
encoder3.enc3norm1.num_batches_tracked
encoder3.enc3conv2.weight
encoder3.enc3norm2.weight
encoder3.enc3norm2.bias
encoder3.enc3norm2.running_mean
en

## Dissection

### 关联库

In [0]:
import argparse
import os

import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
from medpy.filter.binary import largest_connected_component
from skimage.io import imsave
from skimage.transform import resize
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import BrainSegmentationDataset as Dataset
from unet import UNet
from utils import dsc, gray2rgb, outline

### function for dissection

In [0]:
def main(args):
    makedirs(args)
    device = torch.device("cpu" if not torch.cuda.is_available() else args.device)

    loader = data_loader(args)

    with torch.set_grad_enabled(False):
        unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
        state_dict = torch.load(args.weights, map_location=device)
        unet.load_state_dict(state_dict)
        unet.eval()
        unet.to(device)

        input_list = []
        pred_list = []
        true_list = []
        Dice_list = []

        for i, data in tqdm(enumerate(loader)):
            x, y_true = data
            y_true=convert_mask(y_true)
            x, y_true = x.to(device), y_true.to(device)

            y_pred = unet(x)
            y_pred_np = y_pred.detach().cpu().numpy()
# Network Dissection

            # feature_map = unet.featuremap_enc1.transpose(1,0).cpu().numpy()
            feature_map = unet.featuremap_dec4.transpose(1,0).cpu().numpy()
            threshold_maps = np.percentile(feature_map, 95, axis=(2,3),keepdims=True)
            masks = feature_map > threshold_maps
            C,B,H,W = masks.shape
            masks = resize(masks,output_shape=(C,B,256,256),order=0,mode="edge",anti_aliasing=False) 
            y_true = y_true.detach().cpu().numpy()
            Dice_list.append(dsc(masks,y_true))
            embed()

            
        dice = 0
        for data in Dice_list:
          dice += data
        dice /= len(Dice_list)

        embed()
        print(dice)

def dsc(masks,gt):

    C,B,H,W = masks.shape
    masks=masks.reshape(C,B,1,H,W)
    B,L,H,W = gt.shape
    gt=gt.reshape(1,B,L,H,W)

    # embed()
    intersection = (gt*masks).sum(axis=(1,3,4))
    dsc = (2. * intersection + 1) / (masks.sum(axis=(1,3,4)) + gt.sum(axis=(1,3,4)) + 1)

    return dsc

def convert_mask(mask):
  # 4 Gd-enhance tumour
  # 2 peritumoral edema
  # 1 necrotic and non-enhancing tumour
  B,L,H,W = mask.shape

  label_1 = (mask[:,1,:,:]+mask[:,2,:,:]).reshape(B,1,H,W)              #1+2
  label_2 = (mask[:,1,:,:]+mask[:,3,:,:]).reshape(B,1,H,W)              #1+4
  label_3 = (mask[:,2,:,:]+mask[:,3,:,:]).reshape(B,1,H,W)              #2+4
  label_4 = (mask[:,1,:,:]+mask[:,2,:,:]+mask[:,3,:,:]).reshape(B,1,H,W)       #1+2+4

  extra_label = torch.cat((label_1,label_2,label_3,label_4),1)

  mask = torch.cat((mask,extra_label),1)
  
  return mask


def data_loader(args):
    dataset = Dataset(
        images_dir=args.images,
        subset="train",
        image_size=args.image_size,
        random_sampling=False,
    )
    loader = DataLoader(
        dataset, batch_size=args.batch_size, drop_last=False, num_workers=1
    )
    return loader


def makedirs(args):
    os.makedirs(args.predictions, exist_ok=True)


### 参数

In [0]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Inference for segmentation of brain MRI"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="device for training (default: cuda:0)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="input batch size for training (default: 32)",
    )
    parser.add_argument(
        "--weights", type=str, default = "/content/weights/unet.pt", help="path to weights file"
    )
    parser.add_argument(
        # "--images", type=str, default="/content/drive/My Drive/Project/DATA/TEST_DATA/HGG2", help="root folder with images"
        "--images", type=str, default="/content/drive/My Drive/Project/DATA/TEST_DATA/HGG2/BraTS19_2013_11_1", help="root folder with images"
    )
    parser.add_argument(
        "--image-size",
        type=int,
        default=256,
        help="target input image size (default: 256)",
    )
    parser.add_argument(
        "--predictions",
        type=str,
        default="/content/Results",
        help="folder for saving images with prediction outlines",
    )
    parser.add_argument(
        "--figure",
        type=str,
        default="./dsc.png",
        help="filename for DSC distribution figure",
    )

    args = parser.parse_args(args=[])
    main(args)


reading train images...
preprocessing train volumes...
one hotting train volumes...
cropping train volumes...
padding train volumes...
resizing train volumes...


  .format(dtypeobj_in, dtypeobj_out))


finish
normalizing train volumes...
converting mask to one of train volumes...
done creating train dataset




0it [00:00, ?it/s][A[A

Python 3.6.9 (default, Apr 18 2020, 01:56:04) 
Type "copyright", "credits" or "license" for more information.

IPython 5.5.0 -- An enhanced Interactive Python.
?         -> Introduction and overview of IPython's features.
%quickref -> Quick reference.
help      -> Python's own help system.
object?   -> Details about 'object', use 'object??' for extra details.

In [1]: y_pred_np.shape
Out[1]: (8, 4, 256, 256)

In [2]: y_true.shape
Out[2]: (8, 8, 256, 256)

In [3]: for i in range(8):
   ...:     plt.imshow(y_true[i,1,:,:],cmap='gray')
   ...:     plt.savefig("Figure{0}".format(i))
   ...: 

In [4]: for i in range(8):
   ...:     plt.imshow(y_pred_np[i,1,:,:],cmap='gray')
   ...:     plt.savefig("Pred{0}".format(i))
   ...: 

In [5]: for i in range(8):
   ...:     plt.imshow(y_pred_np[i,1,:,:],cmap='gray')
   ...:     plt.savefig("Pred{0}".format(i))
   ...: 

In [6]: y_pred.shape
Out[6]: torch.Size([8, 4, 256, 256])

In [7]: np.unique(y_pred[5,1,:,:])
[0;31m-----------------------------

  """Entry point for launching an IPython kernel.



In [13]: one_hot
Out[13]: False

In [14]: max_output = np.max(y_pred_np,axis=1)

In [15]: one_hot = (y_pred_np == max_output)


  """Entry point for launching an IPython kernel.



In [16]: one_hot
Out[16]: False

In [17]: max_output.shape
Out[17]: (8, 256, 256)

In [18]: one_hot.shape
[0;31m---------------------------------------------------------------------------[0m
[0;31mAttributeError[0m                            Traceback (most recent call last)
[0;32m<ipython-input-18-311a53ec5a1c>[0m in [0;36m<module>[0;34m()[0m
[0;32m----> 1[0;31m [0mone_hot[0m[0;34m.[0m[0mshape[0m[0;34m[0m[0;34m[0m[0m
[0m
[0;31mAttributeError[0m: 'bool' object has no attribute 'shape'

In [19]: max_output.shape
Out[19]: (8, 256, 256)

In [20]: y_pred.shape
Out[20]: torch.Size([8, 4, 256, 256])

In [21]: max_output.reshape(8,1,256,256)
Out[21]: 
array([[[[0.46463883, 0.4689224 , 0.4684771 , ..., 0.46688044,
          0.46904683, 0.46834275],
         [0.47036812, 0.47078604, 0.4673641 , ..., 0.46712914,
          0.47105184, 0.47003615],
         [0.47006738, 0.46754083, 0.46809888, ..., 0.46784392,
          0.4675883 , 0.46928594],
         ...,
         [0.4

KeyboardInterrupt: ignored

<Figure size 432x288 with 1 Axes>


In [0]:
m, s = np.mean(input_image, axis=(0, 1)), np.std(input_image, axis=(0, 1))

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=m, std=s),
])

input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)

# plt.imshow(input_image.squeeze(), cmap = 'gray')

if torch.cuda.is_available():
  input_batch = input_batch.to('cuda',dtype=torch.float)
  model = model.to('cuda')

with torch.no_grad():
    output = model(input_batch)

embed()
output = output.cpu().numpy()
# max_output = np.max(output,axis=1)
# one_hot = (output == max_output)

max_pred=np.max(output,axis=1)
b,h,w = max_pred.shape
max_pred = max_pred.reshape(b,1,h,w)
one_hot = (output == max_pred)

# plt.imshow(one_hot[0,3,:,:], cmap = 'gray')
# plt.savefig('./test5.jpg')

# 下载模型

In [0]:
from google.colab import files
files.download('/content/weights/unet.pt')

# GPU参数

In [0]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime → "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Thu Apr 23 06:36:49 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.64.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   78C    P0    27W /  75W |   4613MiB /  7611MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
+-------