## IMPORTS

In [1]:
import numpy as np
from monai.data.image_reader import ImageReader, ITKReader
from ipywidgets.widgets import *
import ipywidgets as widgets
import matplotlib.pyplot as plt

import pytorch_lightning
from monai.utils import set_determinism
from monai.transforms import (
    AsDiscrete,
    AddChanneld,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    EnsureChannelFirstd,
    Resized,
)
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, list_data_collate, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob

from monai.transforms.spatial.array import Resize

# Ask PyTorch to empty its CUDA cache before any model is created in this kernel.
torch.cuda.empty_cache()
# Print the MONAI / dependency version report (captured in the cell output below).
print_config()


MONAI version: 0.8.1
Numpy version: 1.21.5
Pytorch version: 1.11.0
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 71ff399a3ea07aef667b23653620a290364095b1

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 3.2.2
scikit-image version: 0.19.2
Pillow version: 9.0.1
Tensorboard version: 2.8.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.12.0
tqdm version: 4.64.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.0
pandas version: 1.4.1
einops version: 0.4.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



# CUSTOM TRANSFORMATION

In [2]:
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

from monai.config import DtypeLike, KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from monai.networks.layers import AffineTransform
from monai.networks.layers.simplelayers import GaussianFilter
from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.array import (
    AddCoordinateChannels,
    Affine,
    AffineGrid,
    Flip,
    GridDistortion,
    Orientation,
    Rand2DElastic,
    Rand3DElastic,
    RandAffine,
    RandAxisFlip,
    RandFlip,
    RandGridDistortion,
    RandRotate,
    RandZoom,
    Resize,
    Rotate,
    Rotate90,
    Spacing,
    Zoom,
)
from monai.transforms.transform import MapTransform, RandomizableTransform
from monai.transforms.utils import create_grid
from monai.utils import (
    GridSampleMode,
    GridSamplePadMode,
    InterpolateMode,
    NumpyPadMode,
    PytorchPadMode,
    ensure_tuple,
    ensure_tuple_rep,
    fall_back_tuple,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import TraceKeys
from monai.utils.module import optional_import
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
from monai.apps import load_from_mmar
from monai.apps.mmars import RemoteMMARKeys
from monai.networks.utils import copy_model_state
from monai.optimizers import generate_param_groups


class InterpolateMode(Enum):
    """Interpolation mode names accepted by ``ResizedC``.

    NOTE(review): this class rebinds the name ``InterpolateMode`` that was
    imported from ``monai.utils`` earlier in this cell, so every later use of
    ``InterpolateMode`` in this notebook refers to this local enum.
    """

    NEAREST = "nearest"
    LINEAR = "linear"
    BILINEAR = "bilinear"
    BICUBIC = "bicubic"
    TRILINEAR = "trilinear"
    AREA = "area"


# A single mode (enum member or its string value) or a sequence of them;
# used as the type of the ``mode`` argument of ``ResizedC.__init__``.
InterpolateModeSequence = Union[
    Sequence[Union[InterpolateMode, str]], InterpolateMode, str
]


class ResizedC(MapTransform, InvertibleTransform):
    """Dictionary-based resize with special handling of the ``"label"`` key.

    Every key is resized to ``spatial_size`` with MONAI's ``Resize``. For the
    key named ``"label"`` the channel layout is rebuilt first: an all-zero
    background channel is created and source channels 1 and 2 are appended, so
    the resized label always has three channels and dtype ``float32``.

    Args:
        keys: keys of the items to be resized.
        spatial_size: target spatial size, forwarded to ``Resize``.
        size_mode: forwarded to ``Resize`` (``"all"`` or ``"longest"``).
        mode: interpolation mode, one value or one per key. It is passed to
            the resizer on every call and recorded for ``inverse``.
        align_corners: forwarded to the resizer, one value or one per key.
        allow_missing_keys: do not raise if a key is missing from the data.
    """

    backend = Resize.backend

    # Key that receives the channel rebuild, and the source channels kept
    # after the synthetic zero background channel.
    LABEL_KEY = "label"
    LABEL_CHANNELS = (1, 2)

    def __init__(
        self,
        keys: KeysCollection,
        spatial_size: Union[Sequence[int], int],
        size_mode: str = "all",
        mode: InterpolateModeSequence = InterpolateMode.AREA,
        align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.mode = ensure_tuple_rep(mode, len(self.keys))
        self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
        self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode)

    def _rebuild_label_channels(self, label: NdarrayOrTensor) -> np.ndarray:
        """Return ``[zeros, label[1], label[2]]`` stacked on the channel axis."""
        background = np.zeros((1, *label.shape[1:]), dtype=np.float32)
        kept = [
            np.expand_dims(np.asarray(label[channel]), 0)
            for channel in self.LABEL_CHANNELS
        ]
        return np.concatenate([background, *kept], axis=0)

    def __call__(
        self, data: Mapping[Hashable, NdarrayOrTensor]
    ) -> Dict[Hashable, NdarrayOrTensor]:
        """Resize every configured key and record the transform for ``inverse``."""
        d = dict(data)
        for key, mode, align_corners in self.key_iterator(
            d, self.mode, self.align_corners
        ):
            # Always hand MONAI the plain string: the local ``InterpolateMode``
            # enum is not the MONAI enum, so its members are not valid options.
            mode_name = mode.value if isinstance(mode, Enum) else mode
            self.push_transform(
                d,
                key,
                extra_info={
                    "mode": mode_name,
                    "align_corners": align_corners
                    if align_corners is not None
                    else TraceKeys.NONE,
                },
            )
            if key == self.LABEL_KEY:
                label = self._rebuild_label_channels(d[key])
                # ``Resize`` is channel-first and interpolates each channel
                # independently, so one call replaces the per-channel loop.
                resized = self.resizer(
                    label, mode=mode_name, align_corners=align_corners
                )
                d[key] = np.asarray(resized, dtype=np.float32)
            else:
                d[key] = self.resizer(
                    d[key], mode=mode_name, align_corners=align_corners
                )
        return d

    def inverse(
        self, data: Mapping[Hashable, NdarrayOrTensor]
    ) -> Dict[Hashable, NdarrayOrTensor]:
        """Resize every configured key back to its recorded original size.

        Only the spatial size is restored; the channel rebuild applied to the
        ``"label"`` key in ``__call__`` is not undone.
        """
        d = deepcopy(dict(data))
        for key in self.key_iterator(d):
            transform = self.get_most_recent_transform(d, key)
            orig_size = transform[TraceKeys.ORIG_SIZE]
            mode = transform[TraceKeys.EXTRA_INFO]["mode"]
            align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"]
            # Create inverse transform
            inverse_transform = Resize(
                spatial_size=orig_size,
                mode=mode,
                align_corners=None
                if align_corners == TraceKeys.NONE
                else align_corners,
            )
            # Apply inverse transform
            d[key] = inverse_transform(d[key])
            # Remove the applied transform
            self.pop_transform(d, key)

        return d


### model
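Uses a PyTorch Lightning `LightningModule` to hold the model, but does not use Lightning's built-in training machinery: training is driven by the manual loop further below.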

In [3]:
# When True, Net.__init__ loads its network with load_from_mmar;
# when False it builds a fresh MONAI UNet instead.
PRETRAINED = True
# When True, Net.__init__ freezes the copied pretrained weights and
# Net.configure_optimizers builds the optimizer from the filtered parameter groups.
TRANSFER_LEARNING = True


class Net(pytorch_lightning.LightningModule):
    """Lightning module wrapping the liver segmentation network.

    Depending on the module-level ``PRETRAINED`` flag the network is either
    loaded with ``load_from_mmar`` (using the module-level ``mmar`` descriptor
    and ``root_dir``, which must be defined before instantiation) or built as a
    fresh MONAI ``UNet``. With ``TRANSFER_LEARNING`` the copied weights are
    frozen and ``self.params`` holds the parameter groups for the optimizer.
    """

    def __init__(self):
        super().__init__()
        # Always defined, so configure_optimizers can test it even when the
        # transfer-learning branch below did not run.
        self.params = None
        if PRETRAINED:
            print("using a pretrained model.")
            unet_model = load_from_mmar(
                item=mmar[RemoteMMARKeys.NAME],
                mmar_dir=root_dir,
                # map_location=device,
                pretrained=True,
            )
            self._model = unet_model
            # copy all the pretrained weights except for variables whose name matches "model.0.conv.unit0"
            if TRANSFER_LEARNING:
                # NOTE(review): self._model and unet_model are the same object
                # here, so this copies the network onto itself — confirm whether
                # a separately constructed target network was intended.
                pretrained_dict, updated_keys, unchanged_keys = copy_model_state(
                    self._model, unet_model, exclude_vars="model.0.conv.unit0"
                )
                print(
                    "num. var. using the pretrained",
                    len(updated_keys),
                    ", random init",
                    len(unchanged_keys),
                    "variables.",
                )
                self._model.load_state_dict(pretrained_dict)
                # stop gradients for the pretrained weights
                for x in self._model.named_parameters():
                    if x[0] in updated_keys:
                        x[1].requires_grad = False
                params = generate_param_groups(
                    network=self._model,
                    layer_matches=[lambda x: x[0] in updated_keys],
                    match_types=["filter"],
                    lr_values=[1e-4],
                    include_others=False,
                )
                self.params = params

        else:
            self._model = UNet(
                spatial_dims=3,
                in_channels=1,
                out_channels=2,
                channels=(16, 32, 64, 128, 256),
                strides=(2, 2, 2, 2),
                num_res_units=2,
                norm=Norm.BATCH,
            )
        self.loss_function = DiceLoss(to_onehot_y=True, softmax=True)
        self.post_pred = Compose(
            [EnsureType("tensor", device="cpu"), AsDiscrete(argmax=True, to_onehot=2)]
        )
        self.post_label = Compose(
            [EnsureType("tensor", device="cpu"), AsDiscrete(to_onehot=2)]
        )
        self.dice_metric = DiceMetric(
            include_background=False, reduction="mean", get_not_nans=False
        )
        self.best_val_dice = 0
        self.best_val_epoch = 0

    def forward(self, x):
        """Run the wrapped network on ``x``."""
        return self._model(x)

    def prepare_data(self):
        """Build ``self.train_ds`` from the first scan/overlay pair on disk.

        Raises:
            FileNotFoundError: if no scan/overlay pair is found.
        """
        # set up the correct data path
        # "\\" is an explicit backslash; the former "\l" was an invalid escape.
        data_root = os.path.join("U:", "\\lauraalvarez", "data", "liver", "train")
        train_images = sorted(glob.glob(os.path.join(data_root, "scans", "*.mha")))
        train_labels = sorted(
            glob.glob(os.path.join(data_root, "overlays", "*.mha"))
        )  # U:\lauraalvarez\data\overlays\overlay_results\overlay_results\liver\train
        data_dicts = [
            {"image": image_name, "label": label_name}
            for image_name, label_name in zip(train_images, train_labels)
        ]
        if not data_dicts:
            raise FileNotFoundError(
                f"No scan/overlay '*.mha' pairs found under {data_root!r}"
            )
        # Only the first pair is used; val_files is kept for the commented-out
        # validation dataset below.
        train_files, val_files = [data_dicts[0]], [data_dicts[0]]
        print("len(train_files)", len(train_files))

        # set deterministic training for reproducibility
        set_determinism(seed=0)

        # define the data transforms
        train_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"], reader=ITKReader),
                EnsureChannelFirstd(keys=["image", "label"]),
                ResizedC(keys=["image"], spatial_size=(96, 96, 96)),
                ResizedC(keys=["label"], spatial_size=(96, 96, 96)),
                # Orientationd(keys=["image", "label"], axcodes="PLI"),
                # AddChanneld(keys=["label"]),
                ScaleIntensityRanged(
                    keys=["image"],
                    a_min=-57,
                    a_max=164,
                    b_min=0.0,
                    b_max=1.0,
                    clip=True,
                ),
            ]
        )
        # val_transforms = Compose(
        #     [
        #         LoadImaged(keys=["image", "label"], reader=ITKReader),
        #         EnsureChannelFirstd(keys=["image", "label"]),
        #         # SelectChannels(keys=["label"], channel_dims=[1,2]),
        #         Orientationd(keys=["image", "label"], axcodes="PLI"),
        #         # AddChanneld(keys=["image", "label"]),
        #         ScaleIntensityRanged(
        #             keys=["image"], a_min=-57, a_max=164,
        #             b_min=0.0, b_max=1.0, clip=True,
        #         ),
        #     ]
        # )

        # we use cached datasets - these are 10x faster than regular datasets
        self.train_ds = CacheDataset(
            data=train_files,
            transform=train_transforms,
            cache_rate=1.0,
            num_workers=20,
        )
        # self.val_ds = CacheDataset(
        #     data=val_files, transform=val_transforms,
        #     cache_rate=1.0, num_workers=20,
        # )
        # print('len(self.val_ds)', len(self.val_ds))

    def train_dataloader(self):
        """Return a shuffled, batch-size-1 loader over ``self.train_ds``."""
        train_loader = torch.utils.data.DataLoader(
            self.train_ds,
            batch_size=1,
            shuffle=True,
            num_workers=1,
            collate_fn=list_data_collate,
        )
        return train_loader

    # def val_dataloader(self):
    #     val_loader = torch.utils.data.DataLoader(
    #         self.val_ds, batch_size=1, num_workers=1)
    #     return val_loader

    def predict_dataloader(self):
        """Return an unshuffled loader over ``self.train_ds`` for prediction."""
        print("predicting dataload...")
        val_loader = torch.utils.data.DataLoader(
            self.train_ds, batch_size=1, num_workers=1
        )
        print("len(val_loader)", len(val_loader))
        return val_loader

    def configure_optimizers(self):
        """Return Adam over the transfer-learning groups, else over all parameters."""
        # self.params is only populated when both PRETRAINED and
        # TRANSFER_LEARNING were set at construction time.
        if TRANSFER_LEARNING and self.params is not None:
            return torch.optim.Adam(self.params, 1e-5)
        return torch.optim.Adam(self._model.parameters(), 1e-4)

    def training_step(self, batch, batch_idx):
        """Compute the Dice loss for one batch and return it with a log dict."""
        print("training started..")
        images, labels = batch["image"], batch["label"]
        # labels = SelectChannels([1, 2], labels)
        print("forwarding..")
        output = self.forward(images)
        print("calculating loss")
        loss = self.loss_function(output, labels)
        tensorboard_logs = {"train_loss": loss.item()}
        return {"loss": loss, "log": tensorboard_logs}

    def predict_step(self, batch, batch_idx):
        """Run sliding-window inference on one batch.

        Returns a dict with the raw network ``output`` plus the input ``image``
        and ``label`` tensors.
        """
        print("predicting...")
        images, labels = batch["image"], batch["label"]
        print("images.shape", images.shape)
        print("labels.shape", labels.shape)
        # labels = SelectChannels([1, 2], labels)
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        outputs = sliding_window_inference(
            images, roi_size, sw_batch_size, self.forward
        )
        # outputs = self.forward(images)
        return {"output": outputs, "image": images, "label": labels}

    # def validation_step(self, batch, batch_idx):
    #     images, labels = batch["image"], batch["label"]
    #     labels = SelectChannels([1,2], labels)
    #     roi_size = (160, 160, 160)
    #     sw_batch_size = 4
    #     outputs = sliding_window_inference(
    #         images, roi_size, sw_batch_size, self.forward)
    #     loss = self.loss_function(outputs, labels)
    #     outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
    #     labels = [self.post_label(i) for i in decollate_batch(labels)]
    #     self.dice_metric(y_pred=outputs, y=labels)
    #     return {"val_loss": loss, "val_number": len(outputs)}

    # def validation_epoch_end(self, outputs):
    #     val_loss, num_items = 0, 0
    #     for output in outputs:
    #         val_loss += output["val_loss"].sum().item()
    #         num_items += output["val_number"]
    #     mean_val_dice = self.dice_metric.aggregate().item()
    #     self.dice_metric.reset()
    #     mean_val_loss = torch.tensor(val_loss / num_items)
    #     tensorboard_logs = {
    #         "val_dice": mean_val_dice,
    #         "val_loss": mean_val_loss,
    #     }
    #     if mean_val_dice > self.best_val_dice:
    #         self.best_val_dice = mean_val_dice
    #         self.best_val_epoch = self.current_epoch
    #     print(
    #         f"current epoch: {self.current_epoch} "
    #         f"current mean dice: {mean_val_dice:.4f}"
    #         f"\nbest mean dice: {self.best_val_dice:.4f} "
    #         f"at epoch: {self.best_val_epoch}"
    #     )
    #     return {"log": tensorboard_logs}


### data loader

In [4]:
import torchio as tio
import os 
import glob
from tqdm.notebook import tqdm

In [5]:

def prepare_data(data_dir=None, spatial_size=(96, 96, 96)):
    """Build the cached training dataset from the first scan/overlay pair.

    Args:
        data_dir: directory containing the ``scans`` and ``overlays``
            sub-folders of ``*.mha`` files. Defaults to the original
            ``U:\\lauraalvarez\\data\\liver\\train`` location.
        spatial_size: size both image and label are resized to.

    Returns:
        A ``CacheDataset`` holding the single selected training sample.

    Raises:
        FileNotFoundError: if no scan/overlay pair is found under ``data_dir``.
    """
    if data_dir is None:
        # "\\" is an explicit backslash; the former "\l" was an invalid escape.
        data_dir = os.path.join("U:", "\\lauraalvarez", "data", "liver", "train")

    # set up the correct data path
    train_images = sorted(glob.glob(os.path.join(data_dir, "scans", "*.mha")))
    train_labels = sorted(glob.glob(os.path.join(data_dir, "overlays", "*.mha")))
    data_dicts = [
        {"image": image_name, "label": label_name}
        for image_name, label_name in zip(train_images, train_labels)
    ]
    if not data_dicts:
        raise FileNotFoundError(
            f"No scan/overlay '*.mha' pairs found under {data_dir!r}"
        )
    # Only the first pair is used.
    train_files = [data_dicts[0]]
    print('len(train_files)', len(train_files))

    # set deterministic training for reproducibility
    set_determinism(seed=0)

    # define the data transforms
    train_transforms = Compose(
        [
            LoadImaged(keys=["image", "label"], reader=ITKReader),
            EnsureChannelFirstd(keys=["image", "label"]),
            ResizedC(keys=["image"], spatial_size=spatial_size),
            ResizedC(keys=["label"], spatial_size=spatial_size),
            ScaleIntensityRanged(
                keys=["image"], a_min=-57, a_max=164,
                b_min=0.0, b_max=1.0, clip=True,
            ),
        ]
    )

    train_ds = CacheDataset(
        data=train_files, transform=train_transforms,
        cache_rate=1.0, num_workers=20,
    )
    return train_ds

### train & val loop

In [6]:
def save_checkpoint(state, name):
    """Save a training checkpoint under ``checkpoints/``.

    The file is named ``<name><epoch>_<score>`` where ``score`` is
    ``state["acc"]`` when present, otherwise ``state["loss"]``, rounded to
    four decimals. If neither key exists the ``_<score>`` suffix is omitted.

    Args:
        state: dict to serialise with ``torch.save``; must contain ``"epoch"``.
        name: file name prefix.
    """
    file_path = "checkpoints/"
    os.makedirs(file_path, exist_ok=True)

    epoch = state["epoch"]
    # The training loop stores "loss" rather than "acc", so fall back to it
    # instead of raising KeyError.
    score = state.get("acc", state.get("loss"))
    suffix = "" if score is None else "_" + str(round(float(score), 4))
    save_dir = file_path + name + str(epoch) + suffix
    torch.save(state, save_dir)
    print(f"Saving checkpoint for epoch {epoch} in: {save_dir}")


def save_state_dict(state, name):
    """Write ``state`` to ``checkpoints/<name>_best``, creating the folder if needed."""
    checkpoint_root = "checkpoints/"
    save_dir = checkpoint_root + f"{name}_best"

    folder_missing = not os.path.exists(checkpoint_root)
    if folder_missing:
        os.makedirs(checkpoint_root)

    torch.save(state, save_dir)
    print(f"Best accuracy so far. Saving model to:{save_dir}")


In [7]:
import numpy as np


def train(log_interval, model, device, train_loader, optimizer, epoch):
    """Run one training epoch and return the per-batch losses.

    Args:
        log_interval: print a summary line every ``log_interval`` batches.
        model: module exposing ``loss_function`` and callable on an image batch.
        device: device the batch tensors are moved to.
        train_loader: iterable of dicts with ``"image"`` and ``"label"`` tensors.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the log line.

    Returns:
        ``(losses, scores)``: ``losses`` holds one float per batch; ``scores``
        is always an empty list (no metric is computed here).
    """
    # set model as training mode
    model.train()

    losses = []
    scores = []

    # declare the bar for TQDM
    pbar = tqdm(total=len(train_loader), position=0, unit="it", leave=True)

    try:
        for batch_idx, batch in enumerate(train_loader):
            images, labels = batch["image"].to(device), batch["label"].to(device)

            optimizer.zero_grad()
            output = model(images)
            loss = model.loss_function(output, labels)
            loss.backward()
            optimizer.step()

            # Record the batch loss; without this every mean below is NaN and
            # the caller's best-model / early-stopping logic never triggers.
            losses.append(loss.item())

            pbar.set_description(f"\tLoss: {np.mean(losses):.4f}")
            pbar.update(1)

            # show information
            if (batch_idx + 1) % log_interval == 0:
                print(
                    f"Train Epoch: {epoch} [{batch_idx + 1}/{len(train_loader)}]\tLoss: {np.mean(losses):.4f}"
                )
    finally:
        # Release the progress bar even if a batch raises.
        pbar.close()

    return losses, scores


### run

#### Test Dataloader

In [8]:
train_loader = torch.utils.data.DataLoader(prepare_data(), batch_size=1, shuffle=True)

# Sanity check: pull only the first batch and show the tensor shapes.
batch = next(iter(train_loader), None)
if batch is not None:
    images = batch["image"]
    labels = batch["label"]
    print(images.shape)
    print(labels.shape)


len(train_files) 1


OptionalImportError: required package `itk` is not installed or the version doesn't match requirement.

#### Set up the training

In [None]:
# Descriptor of the remote pretrained model archive (MMAR).
# Net.__init__ reads mmar[RemoteMMARKeys.NAME] and passes it to load_from_mmar.
mmar = {
        RemoteMMARKeys.ID: "clara_pt_liver_and_tumor_ct_segmentation_1",
        RemoteMMARKeys.NAME: "clara_pt_liver_and_tumor_ct_segmentation",
        RemoteMMARKeys.FILE_TYPE: "zip",
        RemoteMMARKeys.HASH_TYPE: "md5",
        # NOTE(review): presumably None skips hash verification of the download — confirm.
        RemoteMMARKeys.HASH_VAL: None,
        RemoteMMARKeys.MODEL_FILE: os.path.join("models", "model.pt"),
        RemoteMMARKeys.CONFIG_FILE: os.path.join("config", "config_train.json"),
        RemoteMMARKeys.VERSION: 1,
    }

In [None]:
# Use MONAI_DATA_DIRECTORY when it is set; otherwise fall back to a fresh temp folder.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = directory
print(root_dir)

In [None]:
# Smoke test: push one random volume through the network and check the shapes.
model = Net()
img3d = torch.randn(1, 1, 96, 96, 96)
print("Input shape: ", img3d.shape)

# No gradients are needed for a shape check, so skip building the autograd graph.
with torch.no_grad():
    preds = model(img3d)
print("Model output size:", preds.shape)

In [None]:
import torch.optim as optim

EPOCHS = 5

batchsize_train = 1
batchsize_eval = 1

log_interval = 2
count = 0  # consecutive epochs without improvement
min_loss = float("inf")  # best (lowest) mean training loss seen so far
max_acc = 0
patience = 3  # stop after this many epochs without improvement

lr = 1e-4


# Declare the model
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model = Net()
model.to(device)


# Set the optimizer
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)


train_dataset = prepare_data()

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batchsize_train, shuffle=True, drop_last=True
)

# range(1, EPOCHS + 1) so that exactly EPOCHS epochs run, numbered from 1.
for epoch in range(1, EPOCHS + 1):
    losses, acces = train(log_interval, model, device, train_loader, optimizer, epoch)
    # An empty loss list yields NaN explicitly (np.mean([]) would also warn);
    # NaN never compares lower than min_loss, so it counts as "not improved".
    loss = np.mean(losses) if len(losses) > 0 else float("nan")

    # Save the state for the best model.
    if loss < min_loss:
        min_loss = loss
        count = 0
        save_state_dict(model.state_dict(), "testing")
    else:
        count += 1
        # This is the training loss: no validation pass is run in this loop.
        print(f"Training loss not improved. > {min_loss}")
        print("Count ->", count)

    if epoch % log_interval == 0:
        # Save checkpoint of the model
        save_checkpoint(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
            },
            "testing_limits",
        )
        print(f"Saved checkpoint for epoch {epoch}")

    if count >= patience:
        print(f"Early termination, {patience} epochs without improvement.")
        break
