In [1]:
!python -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
!python -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

In [2]:
import openfl.native as fx

# Setup default workspace, logging, etc. Install additional requirements
fx.init('torch_unet_kvasir')

Creating Workspace Directories
Creating Workspace Templates


  return torch._C._cuda_getDeviceCount() > 0


Successfully installed packages from /home/maksim/.local/workspace/requirements.txt.

New workspace directory structure:
workspace
├── agg_to_col_one_signed_cert.zip
├── code
│   ├── data_loader.py
│   ├── fed_unet_runner.py
│   ├── pt_unet_parts.py
│   └── __init__.py
├── plan
│   ├── data.yaml
│   ├── plan.yaml
│   ├── cols.yaml
│   └── defaults
│       ├── aggregator.yaml
│       ├── network.yaml
│       ├── assigner.yaml
│       ├── collaborator.yaml
│       ├── tasks_torch.yaml
│       ├── tasks_tensorflow.yaml
│       ├── tasks_keras.yaml
│       ├── tasks_fast_estimator.yaml
│       ├── data_loader.yaml
│       ├── task_runner.yaml
│       └── defaults
├── logs
├── save
│   ├── torch_unet_kvasir_init.pbuf
│   ├── torch_unet_kvasir_last.pbuf
│   └── torch_unet_kvasir_best.pbuf
├── .workspace
├── cert
│   ├── ca
│   │   ├── signing-ca
│   │   ├── signing-ca.csr
│   │   ├── root-ca
│   │   ├── signing-ca.crt
│   │   └── root-ca.crt
│   ├── server
│   │   ├── agg_none.csr
│   │   ├─

In [3]:
!pwd

/home/maksim/.local/workspace


In [4]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import sys
import json

import matplotlib.pyplot as plt
import numpy as np
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader
from monai.data import CacheDataset
from monai.data import Dataset 
from monai.data import (load_decathlon_datalist, load_decathlon_properties)
from monai.transforms import Randomizable
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import (
    Activations,
    AsChannelFirstd,
    AsDiscrete,
    CenterSpatialCropd,
    Compose,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    ToTensord,
)
from monai.utils import set_determinism

import torch
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union, Dict
from openfl.federated import FederatedModel, FederatedDataSet
from openfl.utilities import TensorKey

# print_config()set

In [5]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor).

    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            # merge label 2 and label 3 to construct TC
            result.append(np.logical_or(d[key] == 2, d[key] == 3))
            # merge labels 1, 2 and 3 to construct WT
            result.append(
                np.logical_or(
                    np.logical_or(d[key] == 2, d[key] == 3), d[key] == 1
                )
            )
            # label 2 is ET
            result.append(d[key] == 2)
            d[key] = np.stack(result, axis=0).astype(np.float32)
        return d

In [6]:
train_transform = Compose(
    [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        RandSpatialCropd(
            keys=["image", "label"], roi_size=[128, 128, 64], random_size=False
        ),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=0.5),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=0.5),
        ToTensord(keys=["image", "label"]),
    ]
)
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AsChannelFirstd(keys="image"),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.5, 1.5, 2.0),
            mode=("bilinear", "nearest"),
        ),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        CenterSpatialCropd(keys=["image", "label"], roi_size=[128, 128, 64]),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys=["image", "label"]),
    ]
)

In [7]:
class BrainFederatedDataset(Randomizable, CacheDataset):
    """
    The Dataset to automatically download the data of Medical Segmentation Decathlon challenge
    (http://medicaldecathlon.com/) and generate items for training, validation or test.
    It will also load these properties from the JSON config file of dataset. user can call `get_properties()`
    to get specified properties or all the properties loaded.
    """

    def __init__(self, collaborator_count, collaborator_num, is_validation, transform):
        self.is_validation = is_validation
        dataset_dir = './data/Task01_BrainTumour/' #os.path.join(root_dir, task)

        self.indices: np.ndarray = np.array([])
        if is_validation:
            transform = val_transform
        else:
            transform = train_transform 
        data = self._generate_data_list(dataset_dir)
        data= data[collaborator_num:: collaborator_count]
        print(len(data), 'len')
        self.is_validation = is_validation
        assert(len(data) > 8)
        validation_size = len(data) // 8
        if is_validation:
            data= data[-validation_size:]
        else:
            data= data[: -validation_size]
        property_keys = [
            "name",
            "description",
            "reference",
            "licence",
            "tensorImageSize",
            "modality",
            "labels",
            "numTraining",
            "numTest",
        ]
        self._properties = load_decathlon_properties(os.path.join(dataset_dir, "dataset.json"), property_keys)
        super().__init__(data, transform, cache_num=1, num_workers=4)

    def get_indices(self) -> np.ndarray:
        """
        Get the indices of datalist used in this dataset.

        """
        return self.indices

    def randomize(self, data: List[int]) -> None:
        self.R.shuffle(data)

    def _generate_data_list(self, dataset_dir: str) -> List[Dict]:
        datalist = load_decathlon_datalist(os.path.join(dataset_dir, "dataset.json"), True, "training")
        return datalist
    
    def __getitem__(self, index):
        tmp = super().__getitem__(index)
        return (tmp['image'], tmp['label'])


In [8]:
class KvasirFederatedDataset(FederatedDataSet):
    def __init__(self, collaborator_count=1, collaborator_num=0, batch_size=1, is_split=False, **kwargs):
        """Instantiate the data object
        Args:
            collaborator_count: total number of collaborators
            collaborator_num: number of current collaborator
            batch_size:  the batch size of the data loader
            **kwargs: additional arguments, passed to super init
        """
        super().__init__([], [], [], [], batch_size, num_classes=2, **kwargs)

        self.collaborator_num = int(collaborator_num)

        self.batch_size = batch_size

        self.training_set = BrainFederatedDataset(
            collaborator_count, collaborator_num, is_validation=False, transform=None
        )
        self.valid_set = BrainFederatedDataset(
            collaborator_count, collaborator_num, is_validation=True, transform=None
        )

        self.train_loader = self.get_train_loader()
        self.val_loader = self.get_valid_loader()

    def get_valid_loader(self, num_batches=None):
        return DataLoader(self.valid_set, num_workers=2, batch_size=self.batch_size)

    def get_train_loader(self, num_batches=None):
        return DataLoader(
            self.training_set, num_workers=2, batch_size=self.batch_size, shuffle=True
        )

    def get_train_data_size(self):
        return len(self.training_set)

    def get_valid_data_size(self):
        return len(self.valid_set)

    def get_feature_shape(self):
        #print('shape',self.valid_set[2]['image'].shape)
        return self.valid_set[0][0].shape

    def split(self, collaborator_count, shuffle=True, equally=True):
        return [
            KvasirFederatedDataset(collaborator_count,
                           collaborator_num, self.batch_size)
            for collaborator_num in range(collaborator_count)
        ]

In [10]:
fl_data = KvasirFederatedDataset(batch_size=6)

  0%|          | 0/1 [00:00<?, ?it/s]

484 len


100%|██████████| 1/1 [00:01<00:00,  1.08s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

484 len


100%|██████████| 1/1 [00:01<00:00,  1.07s/it]


In [11]:
class UnetWrapper(UNet):
    def __init__(self):
        super().__init__(
            dimensions=3,
            in_channels=4,
            out_channels=3,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,)
        
    def validate(
        self, col_name, round_num, input_tensor_dict, use_tqdm=False, **kwargs
    ):
        """ Validate. Redifine function from PyTorchTaskRunner, to use our validation"""
        self.rebuild_model(round_num, input_tensor_dict, validation=True)
        dice_metric = DiceMetric(include_background=True, reduction="mean")
        self.eval()
        self.to(self.device)
        val_score = 0
        total_samples = 0

        loader = self.data_loader.get_valid_loader()
        if use_tqdm:
            loader = tqdm.tqdm(loader, desc="validate")
#-------------User code---------------------------------------------------------------------------
        dice_metric = DiceMetric(include_background=True, reduction="mean")
        post_trans = Compose(
            [Activations(sigmoid=True), AsDiscrete(threshold_values=True)]
        )
        metric_sum = metric_sum_tc = metric_sum_wt = metric_sum_et = 0.0
        metric_count = (
            metric_count_tc
        ) = metric_count_wt = metric_count_et = 0
        
        with torch.no_grad():
            for val_inputs,val_labels in loader:
                
                val_outputs = self(val_inputs)
                val_outputs = post_trans(val_outputs)
                # compute overall mean dice
                value, not_nans = dice_metric(y_pred=val_outputs, y=val_labels)
                not_nans = not_nans.item()
                metric_count += not_nans
                metric_sum += value.item() * not_nans

            metric = metric_sum / metric_count
#--------------------------------------------------------------------------                
                
                

        origin = col_name
        suffix = "validate"
        if kwargs["apply"] == "local":
            suffix += "_local"
        else:
            suffix += "_agg"
        tags = ("metric", suffix)
        output_tensor_dict = {
            TensorKey("dice_coef", origin, round_num, True, tags): np.array(
                metric
            )
        }
        return output_tensor_dict, {}
    
#Wrapper, because our train_batches set (output, target) args, but DiceLoss  recieve (input, target)
class DiceLossHeir(DiceLoss):
    __name__ = 'DiceLoss'
    def forward(self,output, target):
        return super().forward(input=output, target=target)
    
loss_function = DiceLossHeir(to_onehot_y=False, sigmoid=True, squared_pred=True)

def optimizer(x): return torch.optim.Adam(
    x, 1e-4, weight_decay=1e-5, amsgrad=True
)



<class '__main__.UnetWrapper'>


In [None]:

fl_model = FederatedModel(build_model=UnetWrapper, optimizer=optimizer,
                          loss_fn=loss_function, data_loader=fl_data)

In [12]:
collaborator_models = fl_model.setup(num_collaborators=2)
collaborators = {'one': collaborator_models[0], 'two': collaborator_models[1]}

  0%|          | 0/1 [00:00<?, ?it/s]

242 len


100%|██████████| 1/1 [00:01<00:00,  1.05s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

242 len


100%|██████████| 1/1 [00:01<00:00,  1.10s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

242 len


100%|██████████| 1/1 [00:01<00:00,  1.11s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

242 len


100%|██████████| 1/1 [00:01<00:00,  1.11s/it]

<class '__main__.UnetWrapper'>





<class '__main__.UnetWrapper'>


In [13]:
# Get the current values of the FL plan. Each of these can be overridden
#print(json.dumps(fx.get_plan(), indent=4, sort_keys=True))

In [14]:
# Run experiment, return trained FederatedModel
final_fl_model = fx.run_experiment(
    collaborators, override_config={'aggregator.settings.rounds_to_train': 30})

  new_state[k] = pt.from_numpy(tensor_dict.pop(k)).to(device)


torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])


  data, target = pt.tensor(data).to(self.device), pt.tensor(
  target).to(self.device, dtype=pt.float32)


after loss tensor(0.9679, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9738, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9539, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9737, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9372, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9513, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9563, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9706, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9455, grad_fn=<MeanBackward0>)
torch.Size([6

torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9472, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9764, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9705, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9602, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9552, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9452, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9393, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9443, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64]

torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9014, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9315, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9117, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9503, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9287, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9252, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9314, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9626, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64]

torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9541, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9239, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.8987, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9393, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9345, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9319, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9443, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9426, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64]

torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9518, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9234, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9146, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9293, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9389, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9502, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.8968, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9458, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64]

torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9147, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9265, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9224, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9477, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9353, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9032, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9169, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9437, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64]

torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9728, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9449, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9087, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9005, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9011, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9004, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9235, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9190, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64]

torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9594, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9531, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9369, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.8785, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9146, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9262, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.8992, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9166, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64]

torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9305, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9311, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9217, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9428, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.8814, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9417, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9045, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64])
after loss tensor(0.9260, grad_fn=<MeanBackward0>)
torch.Size([6, 4, 128, 128, 64]) torch.Size([6, 3, 128, 128, 64]

KeyboardInterrupt: 

In [None]:
# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
