In this notebook I adapt a 3D U-Net designed for segmentation and use it for image generation with a diffusion process.
Once ready, it will be run on the cluster with varying arguments (grid search).

Run:  
papermill input_notebook.ipynb output_notebook.ipynb -p arg1 value1 -p arg2 value2
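
For the grid search, the same call can be driven from Python instead of the shell. The sketch below is one possible driver and is not part of the original notebook: the grid values and the output file naming are placeholders, and it assumes the parameter names `arg_input_dir`, `arg_batch_size` and `arg_n_epochs` from the defaults cell. It relies on `papermill.execute_notebook`, which is imported lazily so the grid can be inspected without papermill installed.

```python
import itertools


def build_grid(**options):
    """Return one parameter dict per combination of the given option lists."""
    names = list(options)
    return [dict(zip(names, combo)) for combo in itertools.product(*options.values())]


def run_grid(input_notebook, grid):
    """Execute the notebook once per parameter set (requires papermill)."""
    import papermill as pm  # lazy import: only needed when actually running

    for params in grid:
        tag = "_".join(f"{k}-{v}" for k, v in params.items())
        pm.execute_notebook(input_notebook, f"output_{tag}.ipynb", parameters=params)


# Placeholder values, chosen only for illustration.
grid = build_grid(
    arg_input_dir=["32x64x64"],
    arg_batch_size=[2, 4],
    arg_n_epochs=[70, 100],
)
print(len(grid))  # 4 combinations
# run_grid("input_notebook.ipynb", grid)
```

papermill injects the values after the cell tagged `parameters`, so the defaults cell needs that tag for the overrides to take effect.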

3D-UNET from here:  
https://github.com/mobarakol/3D_Attention_UNet/tree/main  
The network runs on my input. The current input size is 32x64x64 (batch size 4); smaller inputs don't work with this network, so the network size would have to be reduced.  
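
A likely reason for the size limit, assuming the six-level encoder with 2x2x2 max pooling that `UNet3D` builds by default: every level after the first halves each spatial axis, so five poolings divide the size by 32. The helper below (`sizes_per_level` is a name made up for this sketch) only does the arithmetic and does not touch the model.

```python
def sizes_per_level(shape, n_levels=6, pool=2):
    """Spatial size seen at each encoder level; every level after the first halves it."""
    sizes = [tuple(shape)]
    for _ in range(n_levels - 1):
        sizes.append(tuple(s // pool for s in sizes[-1]))
    return sizes


# Tensor layout in this notebook is (H, W, D) = (64, 64, 32).
for level, size in enumerate(sizes_per_level((64, 64, 32))):
    print(level, size)

# Halving the input once more leaves an axis of size 0 at the bottleneck.
print(sizes_per_level((32, 32, 16))[-1])
```

Under this assumption the 32-voxel axis is already down to 1 at the bottleneck, so passing a shorter `f_maps` list (fewer levels) would be one way to accept smaller inputs.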

Paper:  
Islam, M., Vibashan, V. S., Jose, V. J. M., Wijethilake, N., Utkarsh, U., & Ren, H. (2020). Brain tumor segmentation and survival prediction using 3D attention UNet. In Brainlesion: Glioma, Multiple Sclerosis, Stroke and Traumatic Brain Injuries: 5th International Workshop, BrainLes 2019, Held in Conjunction with MICCAI 2019, Shenzhen, China, October 17, 2019, Revised Selected Papers, Part I 5 (pp. 262-272). Springer International Publishing.

Implementing diffusion from here:  
https://github.com/huggingface/diffusion-models-class/blob/main/unit1/01_introduction_to_diffusers.ipynb

Current TODO:
* Fine-tuning
* Pick better training images!!
* Is set seed possible? (for reproducibility)
* Look into model complexity, maybe model is too complex for the data?
* Deal with padding - current output has a "frame".
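
On the seed question: seeding is possible. A minimal sketch, assuming it is enough to seed Python's `random`, NumPy and PyTorch; `set_seed` is a hypothetical helper, and the imports are guarded only so the snippet also runs where a library is missing.

```python
import random


def set_seed(seed):
    """Seed every random number generator that is importable in this environment."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the CPU generator and the CUDA devices
    except ImportError:
        pass


set_seed(0)
```

Bit-exact repeatability on a GPU may additionally need `torch.use_deterministic_algorithms(True)`, which can slow training down.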
  
Fine-tuning Considerations:
* Current attention - spatial+channel - need to ensure that it doesn't overly focus on small details and ignore broader features.
* create_conv order parameter - can change the relu type and add normalization. num_groups can be changed too if used. Also, batchnorm seems to be possible either before or after conv (adding b to string).
* Activation Order: create_conv can include batch normalization, group normalization and so on - depending on the task, the optimal order of operations might differ.
* Regularization - dropouts, weight decay..
* Final activation - the input is currently normalized to [0, 1], so a sigmoid on the output is probably wanted.
* Input size and depth of U-Net.
* Fine-tune augmentation.
* Early stopping? I'm saving weights every 5 epochs anyway.
* Lower learning rate? or learning rate annealing?
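
For the early-stopping item, a patience-based check is one common option. The class below is an illustrative sketch, not code from this notebook; `patience` and `min_delta` are assumed knobs, and the loss values in the usage loop are made up. Since the diffusion training loss is noisy (timesteps are sampled at random), feeding it a per-epoch average rather than per-batch values is probably the safer choice.

```python
class EarlyStopping:
    """Signal a stop once the monitored loss has not improved for `patience` checks."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, loss):
        """Record a new loss value; return True when training should stop."""
        if loss < self.best - self.min_delta:
            self.best = loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


stopper = EarlyStopping(patience=2)
for epoch_loss in [0.9, 0.7, 0.71, 0.72, 0.5]:  # made-up values
    if stopper.step(epoch_loss):
        print("stop, best loss:", stopper.best)
        break
```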

Comments:
* torchsummary is installed via pip from the current GitHub version rather than the default PyPI release, because only the GitHub version provides summary_string.

In [1]:
import os
import torch
import matplotlib.pyplot as plt
import torchvision
import logging

import sys
import tifffile as tif

import numpy as np
np.set_printoptions(formatter={'float': '{: 0.2f}'.format})

In [2]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)
#logging.info("Starting")

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



## Set default values for arguments:
Arguments are set when the notebook is called; the values below are the defaults that papermill overrides.

In [32]:
arg_input_dir = '32x64x64'
arg_batch_size = 4 # looks like 4 might be the maximum. WHY??
arg_n_epochs = 70

In [5]:
output_dir = f'{arg_input_dir}_batchSize{arg_batch_size}'
output_path = os.path.join('results',output_dir)
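# Later cells write the model summary, weights and losses here, so the directory must exist.
os.makedirs(output_path, exist_ok=True)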

In [6]:
is_plot = False
if not is_plot:
    import matplotlib
    matplotlib.use('Agg')

## Define Data Loader

In [7]:
from glob import glob
import torchio as tio
from torchio import RandomBlur, RandomElasticDeformation, RandomFlip, RescaleIntensity
from torch.utils.data import DataLoader



In [8]:
image_paths = glob(os.path.join('data', 'training_set_20230816', arg_input_dir, '*.tif'))

logging.info(f'Number of training images: {len(image_paths)}')

INFO:root:Number of training images: 684


In [9]:
## going with pretty conservative transform values for now:
transforms = tio.Compose([
    RescaleIntensity(out_min_max=(0, 1)),
    RandomElasticDeformation(num_control_points=7, max_displacement=3),
    RandomFlip(axes=(0, 1, 2)),
    RandomBlur(std=0.5),  # alternative to try: std=(0.5, 1.0)
])

In [10]:
subjects = []
for image_path in image_paths:
    subject = tio.Subject(
        image=tio.ScalarImage(image_path),
    )
    subjects.append(subject)

dataset = tio.SubjectsDataset(subjects, transform=transforms)

In [11]:
dataloader = DataLoader(dataset, batch_size=arg_batch_size, shuffle=True)

In [12]:
# going through my data once (while introducing transformations)
for batch in dataloader:
    images = batch['image']
    break ## just to do one iteration

In [13]:
logging.info(f'Model input size: {images["data"].shape}')

INFO:root:Model input size: torch.Size([4, 1, 64, 64, 32])


## Define corruption

In [14]:
from diffusers import DDPMScheduler

INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpy9hfs9sl
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpy9hfs9sl/_remote_module_non_scriptable.py


In [15]:
# # going through my data once (while introducing transformations)
# for batch in dataloader:
#     images = batch['image']
#     break

In [16]:
def corrupt(x, amount):
    """Corrupt the input `x` by mixing it with noise according to `amount`"""
    noise = torch.rand_like(x)
    #noise = torch.randn_like(x) * x.std() + x.mean()
    amount = amount.view(-1, 1, 1, 1, 1) # Add an extra dimension for 3D data
    return x*(1-amount) + noise*amount 


In [17]:
# Adding noise
amount = torch.linspace(0, 1, images['data'].shape[0]) # Left to right -> more corruption
noised_x = corrupt(images['data'], amount)

In [18]:
slice_index = 4
x_slice = images['data'][:, :, :, :, slice_index]
x_noised_slice = noised_x[:,:,:,:,slice_index]

In [19]:
if is_plot:
    # Plotting the input data
    fig, axs = plt.subplots(2, 1, figsize=(12, 7))
    axs[0].set_title('Input data')
    axs[0].imshow(torchvision.utils.make_grid(x_slice)[0], cmap='Greys')
    
    # Plotting the noised version
    axs[1].set_title('Corrupted data (-- amount increases -->)')
    axs[1].imshow(torchvision.utils.make_grid(x_noised_slice)[0], cmap='Greys');

In [20]:
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

In [21]:
# One timestep per image in the batch, kept on the same device as the images
timesteps = torch.linspace(0, 999, images["data"].shape[0]).long()
noise = torch.randn_like(images["data"])
noisy_x = noise_scheduler.add_noise(images["data"], noise, timesteps)
#print("Noisy X shape", noisy_x.shape)
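
`add_noise` applies the closed-form DDPM forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. The torch-free sketch below recomputes the two coefficients for intuition, assuming the linear beta schedule that `DDPMScheduler` uses by default (`beta_start=0.0001`, `beta_end=0.02`); `ddpm_coefficients` is a made-up helper, not the library's own code.

```python
import math


def ddpm_coefficients(num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Return (signal, noise) scale per timestep for a linear beta schedule."""
    signal, noise = [], []
    alpha_bar = 1.0
    for t in range(num_train_timesteps):
        beta = beta_start + (beta_end - beta_start) * t / (num_train_timesteps - 1)
        alpha_bar *= 1.0 - beta  # cumulative product of (1 - beta)
        signal.append(math.sqrt(alpha_bar))
        noise.append(math.sqrt(1.0 - alpha_bar))
    return signal, noise


signal, noise = ddpm_coefficients()
for t in (0, 333, 666, 999):
    print(t, round(signal[t], 4), round(noise[t], 4))
```

At t=0 the image is almost untouched, while at t=999 it is essentially pure Gaussian noise, which is what the two plotted slices compare.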

In [22]:
if is_plot:
    fig, ax = plt.subplots(1, 2, figsize=(10,10))
    ax[0].imshow(noisy_x[0,0,:,:,slice_index])
    ax[0].axis('off')
    ax[1].imshow(noisy_x[3,0,:,:,slice_index])
    ax[1].axis('off')

## Define NN:

In [24]:
import torch.nn as nn
from torch.nn import functional as F
from torchsummary import summary, summary_string

### Building blocks for the NN:

In [25]:
# 3D version of the SCA (Spatial and Channel Attention) layer:
# an attention mechanism that takes into account both spatial and channel-wise relationships within the input data
class SCA3D(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.channel_excitation = nn.Sequential(nn.Linear(channel, int(channel // reduction)),
                                                nn.ReLU(inplace=True),
                                                nn.Linear(int(channel // reduction), channel))
        self.spatial_se = nn.Conv3d(channel, 1, kernel_size=1,
                                    stride=1, padding=0, bias=False)

    def forward(self, x):
        bahs, chs, _, _, _ = x.size() ## bahs=batch_size
        chn_se = self.avg_pool(x).view(bahs, chs)
        chn_se = torch.sigmoid(self.channel_excitation(chn_se).view(bahs, chs, 1, 1,1))
        chn_se = torch.mul(x, chn_se)
        spa_se = torch.sigmoid(self.spatial_se(x))
        spa_se = torch.mul(x, spa_se)
        net_out = spa_se + x + chn_se
        return net_out


def conv3d(in_channels, out_channels, kernel_size, bias, padding=1):
    return nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)


def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=1):
    """
    Create a list of modules that together constitute a single conv layer with non-linearity
    and optional batchnorm/groupnorm.
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        order (string): order of things, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
        padding (int): add zero-padding to the input
    Return:
        list of tuple (name, module)
    """
    assert 'c' in order, "Conv layer MUST be present"
    assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'

    modules = []
    for i, char in enumerate(order):
        if char == 'r':
            modules.append(('ReLU', nn.ReLU(inplace=True)))
        elif char == 'l':
            modules.append(('LeakyReLU', nn.LeakyReLU(negative_slope=0.1, inplace=True)))
        elif char == 'e':
            modules.append(('ELU', nn.ELU(inplace=True)))
        elif char == 'c':
            # add learnable bias only in the absence of batchnorm/groupnorm
            bias = not ('g' in order or 'b' in order)
            modules.append(('conv', conv3d(in_channels, out_channels, kernel_size, bias, padding=padding)))
        elif char == 'g':
            is_before_conv = i < order.index('c')
            assert not is_before_conv, 'GroupNorm MUST go after the Conv3d'
            # number of groups must be less or equal the number of channels
            if out_channels < num_groups:
                num_groups = out_channels
            modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=out_channels)))
        elif char == 'b':
            is_before_conv = i < order.index('c')
            if is_before_conv:
                modules.append(('batchnorm', nn.BatchNorm3d(in_channels)))
            else:
                modules.append(('batchnorm', nn.BatchNorm3d(out_channels)))
        else:
            raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c']")

    return modules


class SingleConv(nn.Sequential):
    """
    Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
    of operations can be specified via the `order` parameter
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8, padding=1):
        super(SingleConv, self).__init__()

        for name, module in create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding=padding):
            self.add_module(name, module)


class DoubleConv(nn.Sequential):
    """
    A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
    We use (Conv3d+ReLU+GroupNorm3d) by default.
    This can be changed however by providing the 'order' argument, e.g. in order
    to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
    Use padded convolutions to make sure that the output (H_out, W_out) is the same
    as (H_in, W_in), so that you don't have to crop in the decoder path.
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
        kernel_size (int): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
            'cl' -> conv + LeakyReLU
            'ce' -> conv + ELU
        num_groups (int): number of groups for the GroupNorm
    """

    def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='crg', num_groups=8):
        super(DoubleConv, self).__init__()
        if encoder:
            # we're in the encoder path
            conv1_in_channels = in_channels
            conv1_out_channels = out_channels // 2
            if conv1_out_channels < in_channels:
                conv1_out_channels = in_channels
            conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
        else:
            # we're in the decoder path, decrease the number of channels in the 1st convolution
            conv1_in_channels, conv1_out_channels = in_channels, out_channels
            conv2_in_channels, conv2_out_channels = out_channels, out_channels

        # conv1
        self.add_module('SingleConv1',
                        SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups))
        # conv2
        self.add_module('SingleConv2',
                        SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups))


class ExtResNetBlock(nn.Module):
    """
    Basic UNet block consisting of a SingleConv followed by the residual block.
    The SingleConv takes care of increasing/decreasing the number of channels and also ensures that the number
    of output channels is compatible with the residual block that follows.
    This block can be used instead of standard DoubleConv in the Encoder module.
    Motivated by: https://arxiv.org/pdf/1706.00120.pdf
    Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, **kwargs):
        super(ExtResNetBlock, self).__init__()

        # first convolution
        self.conv1 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
        # residual block
        self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups)
        # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
        n_order = order
        for c in 'rel':
            n_order = n_order.replace(c, '')
        self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
                                num_groups=num_groups)

        # create non-linearity separately
        if 'l' in order:
            self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        elif 'e' in order:
            self.non_linearity = nn.ELU(inplace=True)
        else:
            self.non_linearity = nn.ReLU(inplace=True)

    def forward(self, x):
        # apply first convolution and save the output as a residual
        out = self.conv1(x)
        residual = out

        # residual block
        out = self.conv2(out)
        out = self.conv3(out)

        out += residual
        out = self.non_linearity(out)

        return out


class Encoder(nn.Module):
    """
    A single module from the encoder path consisting of the optional max
    pooling layer (one may specify the MaxPool kernel_size to be different
    than the standard (2,2,2), e.g. if the volumetric data is anisotropic
    (make sure to use complementary scale_factor in the decoder path) followed by
    a DoubleConv module.
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        conv_kernel_size (int): size of the convolving kernel
        apply_pooling (bool): if True use MaxPool3d before DoubleConv
        pool_kernel_size (tuple): the size of the window to take a max over
        pool_type (str): pooling layer: 'max' or 'avg'
        basic_module(nn.Module): either ResNetBlock or DoubleConv
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
    """

    def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
                 pool_kernel_size=(2, 2, 2), pool_type='max', basic_module=DoubleConv, conv_layer_order='crg',
                 num_groups=8):
        super(Encoder, self).__init__()
        assert pool_type in ['max', 'avg']
        if apply_pooling:
            if pool_type == 'max':
                self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
            else:
                self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
        else:
            self.pooling = None

        self.basic_module = basic_module(in_channels, out_channels,
                                         encoder=True,
                                         kernel_size=conv_kernel_size,
                                         order=conv_layer_order,
                                         num_groups=num_groups)

    def forward(self, x):
        if self.pooling is not None:
            x = self.pooling(x)
        #x = self.scse(x)
        x = self.basic_module(x)
        return x


class Decoder(nn.Module):
    """
    A single module for decoder path consisting of the upsample layer
    (either learned ConvTranspose3d or interpolation) followed by a DoubleConv
    module.
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int): size of the convolving kernel
        scale_factor (tuple): used as the multiplier for the image H/W/D in
            case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
            from the corresponding encoder
        basic_module(nn.Module): either ResNetBlock or DoubleConv
        conv_layer_order (string): determines the order of layers
            in `DoubleConv` module. See `DoubleConv` for more info.
        num_groups (int): number of groups for the GroupNorm
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 scale_factor=(2, 2, 2), basic_module=DoubleConv, conv_layer_order='crg', num_groups=8):
        super(Decoder, self).__init__()
        if basic_module == DoubleConv:
            # if DoubleConv is the basic_module use nearest neighbor interpolation for upsampling
            self.upsample = None
        else:
            # otherwise use ConvTranspose3d (bear in mind your GPU memory)
            # make sure that the output size reverses the MaxPool3d from the corresponding encoder
            # (D_out = (D_in − 1) ×  stride[0] − 2 ×  padding[0] +  kernel_size[0] +  output_padding[0])
            # also scale the number of channels from in_channels to out_channels so that summation joining
            # works correctly
            self.upsample = nn.ConvTranspose3d(in_channels,
                                               out_channels,
                                               kernel_size=kernel_size,
                                               stride=scale_factor,
                                               padding=1,
                                               output_padding=1)
            # adapt the number of in_channels for the ExtResNetBlock
            in_channels = out_channels

        self.scse = SCA3D(in_channels)

        self.basic_module = basic_module(in_channels, out_channels,
                                         encoder=False,
                                         kernel_size=kernel_size,
                                         order=conv_layer_order,
                                         num_groups=num_groups)

    def forward(self, encoder_features, x):
        if self.upsample is None:
            # use nearest neighbor interpolation and concatenation joining
            output_size = encoder_features.size()[2:]
            x = F.interpolate(x, size=output_size, mode='nearest')
            # concatenate encoder_features (encoder path) with the upsampled input across channel dimension
            x = torch.cat((encoder_features, x), dim=1)
        else:
            # use ConvTranspose3d and summation joining
            x = self.upsample(x)
            x += encoder_features
        x = self.scse(x)
        x = self.basic_module(x)
        return x


class FinalConv(nn.Sequential):
    """
    A module consisting of a convolution layer (e.g. Conv3d+ReLU+GroupNorm3d) and the final 1x1 convolution
    which reduces the number of channels to 'out_channels'.
    We use (Conv3d+ReLU+GroupNorm3d) by default.
    This can be changed however by providing the 'order' argument, e.g. in order
    to change to Conv3d+BatchNorm3d+ReLU use order='cbr'.
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output channels
        kernel_size (int): size of the convolving kernel
        order (string): determines the order of layers, e.g.
            'cr' -> conv + ReLU
            'crg' -> conv + ReLU + groupnorm
        num_groups (int): number of groups for the GroupNorm
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, order='crg', num_groups=8):
        super(FinalConv, self).__init__()

        # conv1
        self.add_module('SingleConv', SingleConv(in_channels, in_channels, kernel_size, order, num_groups))

        # in the last layer a 1×1 convolution reduces the number of output channels to out_channels
        final_conv = nn.Conv3d(in_channels, out_channels, 1)
        self.add_module('final_conv', final_conv)
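
Since the `order` string is one of the main fine-tuning knobs, here is a torch-free sketch of how `create_conv` reads it. `describe_order` and `LAYER_NAMES` are hypothetical names introduced for illustration; the function mirrors the checks and the bias rule above but builds no modules.

```python
LAYER_NAMES = {
    "c": "Conv3d",
    "r": "ReLU",
    "l": "LeakyReLU",
    "e": "ELU",
    "g": "GroupNorm",
    "b": "BatchNorm3d",
}


def describe_order(order):
    """Translate an order string such as 'crg' into layer names and the conv bias flag."""
    if "c" not in order:
        raise ValueError("Conv layer MUST be present")
    if order[0] in "rle":
        raise ValueError("Non-linearity cannot be the first operation in the layer")
    if "g" in order and order.index("g") < order.index("c"):
        raise ValueError("GroupNorm MUST go after the Conv3d")
    unknown = set(order) - set(LAYER_NAMES)
    if unknown:
        raise ValueError(f"Unsupported layer type(s): {sorted(unknown)}")
    # The convolution only gets a learnable bias when no normalisation follows or precedes it.
    conv_has_bias = not ("g" in order or "b" in order)
    return [LAYER_NAMES[ch] for ch in order], conv_has_bias


print(describe_order("crg"))  # default used throughout this notebook
print(describe_order("bcr"))  # batchnorm placed before the convolution
print(describe_order("cl"))   # no normalisation, so the conv keeps its bias
```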

### Define unet:

In [26]:
def create_feature_maps(init_channel_number, number_of_fmaps):
    return [init_channel_number * 2 ** k for k in range(number_of_fmaps)]

class UNet3D(nn.Module):
    """
    3DUnet model from
    `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
        <https://arxiv.org/pdf/1606.06650.pdf>`.
    Args:
        in_channels (int): number of input channels
        out_channels (int): number of output segmentation masks;
            Note that out_channels might correspond to either
            different semantic classes or to different binary segmentation masks.
            It's up to the user of the class to interpret the out_channels and
            use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class)
            or BCEWithLogitsLoss (two-class) respectively)
        f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number
            of feature maps is given by the geometric progression: f_maps * 2^k, k=0,1,...,5
        final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the
            final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used
            to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model.
        layer_order (string): determines the order of layers
            in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d.
            See `SingleConv` for more info
        init_channel_number (int): number of feature maps in the first conv layer of the encoder; default: 64
        num_groups (int): number of groups for the GroupNorm
    """

    def __init__(self, in_channels, out_channels, final_sigmoid, f_maps=16, layer_order='crg', num_groups=8,
                 **kwargs):
        super(UNet3D, self).__init__()

        if isinstance(f_maps, int):
            # the paper suggests 4 levels in the encoder path; this notebook builds 6
            f_maps = create_feature_maps(f_maps, number_of_fmaps=6)

        # create encoder path consisting of Encoder modules. The length of the encoder is equal to `len(f_maps)`
        # uses DoubleConv as a basic_module for the Encoder
        encoders = []
        for i, out_feature_num in enumerate(f_maps):
            if i == 0:
                encoder = Encoder(in_channels, out_feature_num, apply_pooling=False, basic_module=DoubleConv,
                                  conv_layer_order=layer_order, num_groups=num_groups)
            else:
                encoder = Encoder(f_maps[i - 1], out_feature_num, basic_module=DoubleConv,
                                  conv_layer_order=layer_order, num_groups=num_groups)
            encoders.append(encoder)

        self.encoders = nn.ModuleList(encoders)

        # create decoder path consisting of the Decoder modules. The length of the decoder is equal to `len(f_maps) - 1`
        # uses DoubleConv as a basic_module for the Decoder
        decoders = []
        reversed_f_maps = list(reversed(f_maps))
        for i in range(len(reversed_f_maps) - 1):
            in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
            out_feature_num = reversed_f_maps[i + 1]
            decoder = Decoder(in_feature_num, out_feature_num, basic_module=DoubleConv,
                              conv_layer_order=layer_order, num_groups=num_groups)
            decoders.append(decoder)

        self.decoders = nn.ModuleList(decoders)

        # in the last layer a 1×1 convolution reduces the number of output
        # channels to the number of labels
        self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)

        if final_sigmoid:
            self.final_activation = nn.Sigmoid()
        else:
            self.final_activation = nn.Softmax(dim=1)

    def forward(self, x):
        # encoder part
        encoders_features = []
        for encoder in self.encoders:
            x = encoder(x)
            # reverse the encoder outputs to be aligned with the decoder
            encoders_features.insert(0, x)

        # remove the last encoder's output from the list
        # !!remember: it's the 1st in the list
        pool_fea = self.avg_pool(encoders_features[0]).squeeze(0).squeeze(1).squeeze(1).squeeze(1)
        encoders_features = encoders_features[1:]
        # decoder part
        for decoder, encoder_features in zip(self.decoders, encoders_features):
            # pass the output from the corresponding encoder and the output
            # of the previous decoder
            x = decoder(encoder_features, x)

        x = self.final_conv(x)

        # apply final_activation (i.e. Sigmoid or Softmax) only for prediction. During training the network outputs
        # logits and it's up to the user to normalize it before visualising with tensorboard or computing validation metric
        if not self.training:
            x = self.final_activation(x)

        return x, pool_fea
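
To make the resulting channel widths concrete, the sketch below repeats the bookkeeping from `UNet3D.__init__` without building any layers. `channel_plan` is a helper name made up for this illustration; `create_feature_maps` is copied from the cell above so the snippet is self-contained.

```python
def create_feature_maps(init_channel_number, number_of_fmaps):
    return [init_channel_number * 2 ** k for k in range(number_of_fmaps)]


def channel_plan(f_maps=16, number_of_fmaps=6):
    """Return encoder output widths and (in, out) channel pairs for the decoders."""
    if isinstance(f_maps, int):
        f_maps = create_feature_maps(f_maps, number_of_fmaps)
    rev = list(reversed(f_maps))
    # Each decoder sees the upsampled features concatenated with the skip connection.
    decoders = [(rev[i] + rev[i + 1], rev[i + 1]) for i in range(len(rev) - 1)]
    return list(f_maps), decoders


encoder_out, decoder_io = channel_plan()
print(encoder_out)  # [16, 32, 64, 128, 256, 512]
print(decoder_io)   # [(768, 256), (384, 128), (192, 64), (96, 32), (48, 16)]
```

Passing a shorter list, for example `f_maps=[16, 32, 64, 128]`, would give a shallower and lighter network, which is relevant to the model-complexity item in the TODO list.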

## Test the NN

In [27]:
def test_unet(model, dataloader):
    for batch in dataloader:
        images = batch['image']['data']  # Extracting the data tensor
        bs, C, H, W, D = images.shape
        #images = images.to(device)  # Moving to the proper device if necessary (e.g., GPU)


        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, noise_scheduler.num_train_timesteps, (bs,), device=images.device
        ).long()

        # Expand timesteps to have the same spatial dimensions as images
        timesteps = timesteps.view(bs, 1, 1, 1, 1)
        timesteps = timesteps.repeat(1, 1, H, W, D).float()

        images_with_timesteps = torch.cat([images, timesteps], dim=1)
        
        # Forward pass through the model
        output, _ = model(images_with_timesteps.to(device))

        logging.info(f'Input shape: {images_with_timesteps.shape}')
        logging.info(f'Output shape: {output.shape}')
        break  # Stop after the first batch

# Initialize the UNet3D model with 2 input channels (image + timestep map) and 1 output channel
unet_model = UNet3D(in_channels=2, out_channels=1, final_sigmoid=True).to(device)

# Test the UNet model
test_unet(unet_model, dataloader)



INFO:root:Input shape: torch.Size([4, 2, 64, 64, 32])
INFO:root:Output shape: torch.Size([4, 1, 64, 64, 32])


In [31]:
for batch in dataloader:
    images = batch['image']['data']
    bs, C, H, W, D = images.shape
    break ## just to do one iteration

print(C, H, W, D)
    
#summary(unet_model, (2, H, W, D)) ## print 
model_summary_str, _ = summary_string(unet_model, 
                                      (2, H, W, D), 
                                      batch_size=-1, 
                                      device=device)#torch.device('cuda:0')

with open(os.path.join(output_path,'model_summary.txt'), 'w') as f:
    f.write(model_summary_str)

1 64 64 32


## Create a Training Loop

In [None]:
# Set the noise scheduler
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2"
)

# Training loop
optimizer = torch.optim.AdamW(unet_model.parameters(), lr=4e-4)

losses = []

for epoch in range(arg_n_epochs):
    for step, batch in enumerate(dataloader):
    
        clean_images = batch["image"]['data'].to(device)
        # Sample noise to add to the images
        noise = torch.randn_like(clean_images)
        bs, C, H, W, D = clean_images.shape

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
        ).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
        
        timesteps = timesteps.view(bs, 1, 1, 1, 1)
        timesteps = timesteps.repeat(1, 1, H, W, D).float()

        noisy_images_with_timesteps = torch.cat([noisy_images, timesteps], dim=1)

        # Get the model prediction
        noise_pred, _ = unet_model(noisy_images_with_timesteps)

        # Calculate the loss
        loss = F.mse_loss(noise_pred, noise)
        loss.backward()
        losses.append(loss.item())

        # Update the model parameters with the optimizer
        optimizer.step()
        optimizer.zero_grad()

    if (epoch + 1) % 5 == 0:
        loss_last_epoch = sum(losses[-len(dataloader) :]) / len(dataloader)
        logging.info(f'Epoch:{epoch+1}, loss: {loss_last_epoch}')

        torch.save(unet_model.state_dict(), 
                   os.path.join(output_path, f'model_weights_epoch_{epoch+1}.pth'))
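
For reference, the loop above corresponds to the simplified DDPM noise-prediction objective. The notation is assumed here rather than taken from the notebook: $x_0$ stands for `clean_images`, $\epsilon$ for `noise`, $\bar\alpha_t$ for the scheduler's cumulative product of $1-\beta_t$, and $\epsilon_\theta$ for `unet_model`, which receives $t$ as a constant extra input channel.

```latex
x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I)

\mathcal{L} = \mathbb{E}_{x_0,\, \epsilon,\, t}
\left[ \left\lVert \epsilon - \epsilon_\theta(x_t, t) \right\rVert^2 \right]
```

`F.mse_loss` averages over all voxels and batch elements, so the code matches this expression up to a constant factor.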


In [None]:
n_batches = len(dataloader)

with open(os.path.join(output_path,'losses.txt'), 'w') as f:
    for i,l in enumerate(losses):
        if i%n_batches==0:
            f.write(f'EPOCH {i//n_batches}\n')
        f.write(f'{str(l)}\n')
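
The per-batch values written above are noisy because every batch draws random timesteps, so a per-epoch average is easier to compare across grid-search runs. A small helper for that, not part of the original notebook (`epoch_means` is a made-up name, and the numbers in the usage line are toy values):

```python
def epoch_means(losses, n_batches):
    """Average a flat list of per-batch losses into one value per completed epoch."""
    n_epochs = len(losses) // n_batches
    return [
        sum(losses[e * n_batches:(e + 1) * n_batches]) / n_batches
        for e in range(n_epochs)
    ]


# Toy numbers for illustration: 2 epochs of 3 batches each.
print(epoch_means([0.9, 0.8, 0.7, 0.5, 0.4, 0.3], n_batches=3))
```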

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))

if is_plot:   
    plt.show()
else:
    plt.savefig(os.path.join(output_path,'losses.png'))

## Generate images:

In [None]:
# Random starting point (one batch of pure-noise volumes, same shape as the last training batch):
sample = torch.randn(bs, C, H, W, D).to(device)

for i, t in enumerate(noise_scheduler.timesteps):

    # Extend timesteps similar to what's done in the training loop
    timesteps = torch.full((bs, 1, H, W, D), t, dtype=torch.float32, device=device)

    # Concatenate sample with timestep
    sample_with_timesteps = torch.cat([sample, timesteps], dim=1)

    # Get model pred
    with torch.no_grad():
        residual, _ = unet_model(sample_with_timesteps)

    # Update sample with step
    sample = noise_scheduler.step(residual, t, sample).prev_sample


In [None]:
if is_plot:
    slice_index = 4
    x_slice = sample[:, :, :, :, slice_index].cpu()
    
    fig, axs = plt.subplots(1, 1, figsize=(12, 7))
    axs.set_title('Generated Images')
    axs.imshow(torchvision.utils.make_grid(x_slice)[0], cmap='Greys')

In [None]:
sample_np = sample.cpu().numpy()

# Save each image in the batch
for idx, img in enumerate(sample_np):
    img_t = np.transpose(img[0], (2, 0, 1))
    tif.imwrite(os.path.join(output_path, f'image_{idx}.tif'), img_t)