# PyTorch Lightning으로 기존 데이터로더를 구현하기

### TrainOptions 살펴보기 (Daniel GitHub)

In [None]:
from options.train_options import TrainOptions
import sys

# Emulate a command-line launch of train.py: program name first, then every
# flag/value pair, then the two value-less switches.
cli_pairs = [
    ('--model', 'ourGAN'),
    ('--dataroot', 'datasets/IXI'),
    ('--name', 'ourGAN_run'),
    ('--which_direction', 'BtoA'),
    ('--lambda_A', '100'),
    ('--batchSize', '16'),
    ('--output_nc', '1'),
    ('--input_nc', '3'),
    ('--gpu_ids', '0'),
    ('--niter', '50'),
    ('--niter_decay', '50'),
    ('--save_epoch_freq', '25'),
    ('--lambda_vgg', '100'),
    ('--checkpoints_dir', 'checkpoints/'),
]
sys.argv = [
    'train.py',
    *(token for pair in cli_pairs for token in pair),
    '--training',
    '--dataset_misalign',
]

# Build the option parser and parse the argv prepared above.
# NOTE(review): opt.parser is presumably an argparse parser that reads
# sys.argv by default - confirm in options/train_options.py.
opt = TrainOptions()
opt.initialize()
args = opt.parser.parse_args()

# Echo a few of the parsed values as a sanity check.
for field in ('model', 'dataroot', 'name', 'niter_decay'):
    print(getattr(args, field))


In [None]:
import h5py
import numpy as np
import torch
import torchio as tio
from typing import Sequence
from monai.data import ArrayDataset, DataLoader, PersistentDataset
from monai.transforms import Compose, RandAffine, Rand2DElastic, Rand3DElastic
import matplotlib.pyplot as plt
from sklearn.metrics import mutual_info_score
import nibabel as nib

### Dataset Visualization

In [None]:
# Open the training .mat file read-only through h5py and report the shape of
# both stored arrays. The handle `f` is left open: later cells index it.
f = h5py.File('datasets/IXI/train/data.mat', mode='r')

for caption, key in (("data_x shape: ", 'data_x'), ("data_y shape: ", 'data_y')):
    print(caption, f[key].shape)

In [None]:
# Imported in this cell because the notebook's only other import of
# blend_images sits in a much later cell, so running cells top to bottom
# would otherwise raise NameError here.
from monai.visualize import blend_images

# Take index 20 along the third axis and index 0 along the fourth axis of each
# stored array, which yields one 2D array per modality.
# NOTE(review): presumably axis 2 is the slice axis and axis 3 the subject/volume
# axis - confirm against the .mat layout.
t1 = f['data_x'][:,:,20,0]
aligned_t2 = f['data_y'][:,:,20,0]


def blend_and_transpose(x, y, alpha=0.3):
    """
    Blend two 2D images with MONAI's ``blend_images`` and return the result channel-last.

    Parameters:
    -----------
    x : ndarray
        First image to blend. Should be a 2D ndarray; a leading channel axis is added internally.
    y : ndarray
        Second image to blend. Should be a 2D ndarray; a leading channel axis is added internally.
    alpha : float, optional
        Blending weight forwarded to ``blend_images``. The higher the alpha, the more weight
        for the second image. Default is 0.3.

    Returns:
    --------
    ndarray
        The blended image with its first (channel) axis moved to the end, i.e. laid out as
        (H, W, C) so it can be handed to ``plt.imshow`` directly.
        NOTE(review): ``blend_images`` is expected to return an RGB (3-channel) image - confirm.

    Examples:
    ---------
    >>> img1 = np.random.rand(10, 10)
    >>> img2 = np.random.rand(10, 10)
    >>> blended_img = blend_and_transpose(img1, img2)
    """
    # x[None] / y[None] prepend the channel axis that blend_images works on.
    blended = blend_images(x[None], y[None], alpha, cmap='hot')
    # Channel-first -> channel-last.
    return np.transpose(blended, (1, 2, 0))


blended_align = blend_and_transpose(t1, aligned_t2)

## Misalign (Rigid, Elastic)

In [None]:

# The notebook only ever does `from monai.<pkg> import ...`, so the bare name
# `monai` is undefined; import the helper explicitly instead of calling
# monai.utils.set_determinism (which raised NameError).
from monai.utils import set_determinism

set_determinism(10) # Fixes the seed (for reproducibility)
# NOTE(review): the transforms below are called directly, outside a MONAI
# Dataset/DataLoader. Confirm that the global seed actually reaches their
# internal random state; if not, call `.set_random_state(seed=...)` on each.

# Prepend a channel axis: the transforms below are applied to a (1, H, W) array.
image = aligned_t2[None]

# Define the MONAI RandAffine transform
rand_affine_transform = RandAffine(
    mode="bilinear",
    prob=1.0,
    spatial_size=None,
    rotate_range=(0.2, 0.2),  # Rotation range in radians
    shear_range=(0.1, 0.1),   # Shear range
    translate_range=(5, 5),  # Translation range in pixels
    padding_mode="border",
)

# Define the MONAI Rand2DElastic transform
rand_2d_elastic_transform = Rand2DElastic(
    mode="bilinear",
    prob=1.0,
    spacing=(30, 30), # Distance between control points
    magnitude_range=(0.1, 0.1), # Deformation magnitude
    rotate_range=(0.1, 0.1),
    shear_range=(0.1, 0.1),
    translate_range=(5, 5),
    padding_mode="border",
)

# Apply the transform; [0] drops the channel axis added above.
misaligned_t2_rigid = rand_affine_transform(image)[0]
misaligned_t2_elastic = rand_2d_elastic_transform(image)[0]

blended_misalign_rigid = blend_and_transpose(t1, misaligned_t2_rigid) # Blended image
blended_misalign_elastic = blend_and_transpose(t1, misaligned_t2_elastic) # Blended image


In [None]:
def plot_images(images, labels, siz=4, cmap=None):
    """
    Plot a list of images side by side in a single row, one titled subplot per image.

    Parameters:
    -----------
    images : list of ndarray
        List of images. Each image should be a 2D or 3D ndarray.
    labels : list of str
        List of labels. Each label corresponds to an image.
    siz : int, optional
        Size of each image when plotted. Default is 4.
    cmap : str, optional
        Colormap passed to ``plt.imshow`` for every image (e.g. 'gray').
        Default is None, in which case matplotlib's default colormap is used.

    Raises:
    -------
    AssertionError
        If the number of images does not match the number of labels.

    Examples:
    ---------
    >>> img1 = np.random.rand(10, 10)
    >>> img2 = np.random.rand(10, 10)
    >>> plot_images([img1, img2], ['Image 1', 'Image 2'])

    >>> img1 = np.random.rand(10, 10)
    >>> img2 = np.random.rand(10, 10)
    >>> plot_images([img1, img2], ['Image 1', 'Image 2'], cmap='gray')
    """
    assert len(images) == len(labels), "Mismatch in number of images and labels"
    n = len(images)
    if n == 0:
        # Nothing to draw; avoids creating a zero-width figure.
        return

    plt.figure(figsize=(siz * n, siz))  # Adjust figure size based on number of images
    for position, (image, label) in enumerate(zip(images, labels), start=1):
        plt.subplot(1, n, position)
        # Pass the colormap per image instead of calling plt.gray(): that call
        # only honoured cmap == 'gray' and also switched the global default
        # colormap for every later plot.
        plt.imshow(image, cmap=cmap)
        plt.title(label)
    plt.show()

In [None]:
# First figure: the T2 slice and its two misaligned versions in grayscale.
# Second figure: the corresponding blends with the T1 slice.
panel_titles = ["Aligned", "Misaligned (Rigid)", "Misaligned (Elastic)"]

plot_images(
    [aligned_t2, misaligned_t2_rigid, misaligned_t2_elastic],
    panel_titles,
    3,
    cmap='gray',
)

plot_images(
    [blended_align, blended_misalign_rigid, blended_misalign_elastic],
    panel_titles,
    3,
)


## Measure of misalignment

1. Mutual Information
2. Cross-Correlation
3. Target Registration Error (TRE)

In [None]:

def calculate_mutual_info(image1, image2):
    """
    Compute the mutual information between the intensities of two images.

    Both images are flattened, binned into a 20x20 joint histogram, and that
    histogram is handed to ``mutual_info_score`` as the contingency table.

    Parameters:
    image1 (np.array): The first image
    image2 (np.array): The second image

    Returns:
    float: The mutual information score
    """
    flat_first = image1.ravel()
    flat_second = image2.ravel()
    # histogram2d returns (counts, x_edges, y_edges); only the counts are needed.
    joint_counts = np.histogram2d(flat_first, flat_second, bins=20)[0]
    return mutual_info_score(None, None, contingency=joint_counts)

# Score the T2 slice against itself (the reference value) and against each of
# its two misaligned versions.
mi_align, mi_misalign, mi_misalign2 = (
    calculate_mutual_info(aligned_t2, other)
    for other in (aligned_t2, misaligned_t2_rigid, misaligned_t2_elastic)
)

print(f"Mutual Information: {mi_align} -> {mi_misalign}")
print(f"Mutual Information: {mi_align} -> {mi_misalign2}")

# Relative drop in mutual information for the first (RandAffine) misalignment.
relative_drop = (mi_align - mi_misalign) / mi_align
print(relative_drop)


## Loop through misalignment process for each slice (dataA -> dataB)

- `PersistentDataset` processes original data sources through the non-random transforms on first use, and stores these intermediate tensor values to an on-disk persistence representation.
- The intermediate processed tensors are loaded from disk on each use for processing by the random-transforms for each analysis request.
- The `PersistentDataset` has a similar memory footprint to the simple Dataset, with performance characteristics close to the CacheDataset at the expense of disk storage.
- Additionally, the cost of first time processing of data is distributed across each first use.

In [None]:
# Report the shape of both arrays stored in the open .mat file.
for caption, key in (("data_x shape: ", 'data_x'), ("data_y shape: ", 'data_y')):
    print(caption, f[key].shape)

In [None]:
def slice_array(array, block_size):
    """
    Cut a 3D array into equally sized blocks along its last axis.

    Block ``i`` of the result, ``out[:, :, :, i]``, holds
    ``array[:, :, i * block_size:(i + 1) * block_size]``.

    Args:
        array (numpy.ndarray): Input array to be sliced. It should have 3 dimensions.
        block_size (int): Size of each block along the third dimension.

    Returns:
        numpy.ndarray: New float64 array with shape
            (array.shape[0], array.shape[1], block_size, num_slices),
            where num_slices is array.shape[2] divided by block_size.

    Raises:
        AssertionError: If the input array does not have 3 dimensions or the block size does not evenly divide
            the shape of the input array along the third dimension.

    """
    assert len(array.shape) == 3, "Input array should have 3 dimensions."
    assert array.shape[2] % block_size == 0, "Block size should evenly divide the array shape."

    rows, cols, depth = array.shape
    num_slices = depth // block_size

    # Split the depth axis into (block index, offset inside block), then swap
    # those two axes so the block index ends up last.
    blocks = np.reshape(array, (rows, cols, num_slices, block_size)).transpose(0, 1, 3, 2)

    # Always hand back an independent, C-ordered float64 copy.
    return np.array(blocks, dtype=np.float64, order='C')

def save_slices_to_nii(sliced_array, output_prefix):
    """
    Write each 3D volume stored along the last axis of a 4D array to its own NIfTI file.

    Volume ``i`` (0-based) is saved as ``"<output_prefix>_<i+1>.nii.gz"`` with an
    identity affine.

    Args:
        sliced_array (numpy.ndarray): Input array containing the slices to be saved. It should have 4 dimensions.
        output_prefix (str): Prefix to be used for the output file names.

    Raises:
        AssertionError: If the input array does not have 4 dimensions.

    """
    assert len(sliced_array.shape) == 4, "Input array should have 4 dimensions."

    volume_count = sliced_array.shape[3]
    # File names are numbered from 1, array indices from 0.
    for number in range(1, volume_count + 1):
        volume = sliced_array[..., number - 1]
        nib.save(
            nib.Nifti1Image(volume, affine=np.eye(4)),
            f"{output_prefix}_{number}.nii.gz",
        )

# Load File (the handle is left open because a later cell still reads from `f`)
f = h5py.File('datasets/IXI/train/data.mat', mode='r')

# Example usage: for each modality take index 0 of the last axis, cut it into
# blocks of 91 along its third axis and write every block to a .nii.gz file.
block_size = 91
for dataset_key, output_prefix in (
    ('data_x', "datasets/IXI/train/t1"),
    ('data_y', "datasets/IXI/train/t2"),
):
    array = f[dataset_key][..., 0]
    sliced_array = slice_array(array, block_size)
    save_slices_to_nii(sliced_array, output_prefix)



In [None]:
# 3D counterparts of the earlier 2D misalignment transforms. Note that this
# rebinds `rand_affine_transform`, replacing the 2D instance defined earlier.
rand_affine_transform = RandAffine(
    mode="bilinear",
    prob=1.0,
    spatial_size=None,
    rotate_range=(0.2, 0.2, 0.2),  # Rotation range in radians
    shear_range=(0.1, 0.1, 0.1),   # Shear range
    translate_range=(5, 5, 5),  # Translation range in pixels
    padding_mode="border",
)

rand_3d_elastic_transform = Rand3DElastic(
    # NOTE(review): "nearest" interpolation here, while the affine transform
    # uses "bilinear" - confirm this is intended for intensity images.
    mode="nearest",
    prob=1.0,
    sigma_range=(5, 8), # Sigma range for smoothing random displacement
    # magnitude_range is a (low, high) pair, not one value per spatial axis;
    # the former third element was never used.
    magnitude_range=(0.1, 0.1), # Deformation magnitude
    rotate_range=(0.1, 0.1, 0.1),
    shear_range=(0.05, 0.05, 0.05),
    translate_range=(3, 3, 3),
    padding_mode="border",
)




In [None]:
import glob, os
from monai.transforms import (
    Compose,
    LoadImage,
    RandSpatialCrop,
    ScaleIntensity,
    EnsureType,
)
from monai.utils import first

root_dir = 'datasets/IXI/train'

# Collect the exported T1 / T2 volumes. Both lists are sorted with the same
# (lexicographic) rule, so entries at equal positions are expected to pair up.
t1s, t2s = (
    sorted(glob.glob(os.path.join(root_dir, pattern)))
    for pattern in ("t1*.nii.gz", "t2*.nii.gz")
)

# T1 volumes are only loaded (channel-first); T2 volumes are loaded and then
# passed through the random 3D elastic transform.
imtrans = Compose([LoadImage(image_only=True, ensure_channel_first=True)])
segtrans = Compose(
    [
        LoadImage(image_only=True, ensure_channel_first=True),
        rand_3d_elastic_transform,
    ]
)

ds = ArrayDataset(img=t1s, img_transform=imtrans, seg=t2s, seg_transform=segtrans)

# Pull a single batch and print both tensor shapes as a smoke test.
use_pinned_memory = torch.cuda.is_available()
loader = torch.utils.data.DataLoader(
    ds,
    batch_size=1,
    num_workers=2,
    pin_memory=use_pinned_memory,
)
im1, im2 = first(loader)
print(im1.shape, im2.shape)

In [None]:
# Imported in this cell because the notebook's only other import of matshow3d
# sits in a later cell, so running cells in order would raise NameError here.
from monai.visualize import matshow3d

# Show indices 20..25 of the last axis of the first sample / first channel,
# one frame per index.
matshow3d(im2[0,0,:,:,20:26], frame_dim=-1)

In [None]:
from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any

import numpy as np
from torch.utils.data import Dataset

from monai.config import DtypeLike
from monai.data import ImageDataset, Dataset
from monai.data.image_reader import ImageReader
from monai.utils import MAX_SEED, get_seed


In [None]:
# Build a MONAI ImageDataset from paired image / segmentation file lists, each
# with its own transform.
# NOTE(review): img_list, seg_list, img_xform and seg_xform are not defined
# anywhere in the visible notebook, so this cell raises NameError as written.
# Presumably the t1s / t2s lists and imtrans / segtrans from the ArrayDataset
# cell were intended - confirm before running.
img_dataset = ImageDataset(
    image_files=img_list,
    seg_files=seg_list,
    transform=img_xform,
    seg_transform=seg_xform,
    image_only=False,
    transform_with_metadata=True,
)

In [None]:
from monai.transforms import (
    EnsureChannelFirstd,
    AsDiscrete,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
)

def transformations():
    """
    Build the dictionary-based transform pipelines for training and validation.

    Both pipelines start with the same deterministic steps (load, channel-first,
    reorient to RAS, resample, intensity scaling, foreground crop). The training
    pipeline additionally ends with a random positive/negative patch crop.

    Returns:
        tuple: ``(train_transforms, val_transforms)``, two ``Compose`` objects.
    """

    def _deterministic_steps():
        # Returns freshly constructed transforms on every call, so the two
        # pipelines never share a transform instance.
        return [
            # LoadImaged with image_only=True is to return the MetaTensors
            # the additional metadata dictionary is not returned.
            LoadImaged(keys=["image", "label"], image_only=True),
            EnsureChannelFirstd(keys=["image", "label"]),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            Spacingd(
                keys=["image", "label"],
                pixdim=(1.5, 1.5, 2.0),
                mode=("bilinear", "nearest"),
            ),
            ScaleIntensityRanged(
                keys=["image"],
                a_min=-57,
                a_max=164,
                b_min=0.0,
                b_max=1.0,
                clip=True,
            ),
            CropForegroundd(keys=["image", "label"], source_key="image"),
        ]

    # randomly crop out patch samples from big
    # image based on pos / neg ratio
    # the image centers of negative samples
    # must be in valid image area
    random_patch_crop = [
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
    ]
    train_transforms = Compose(_deterministic_steps() + random_patch_crop)

    # NOTE: No random cropping in the validation data,
    # we will evaluate the entire image using a sliding window.
    val_transforms = Compose(_deterministic_steps())

    return train_transforms, val_transforms

In [None]:
from monai.transforms import Compose, ToTensor, Resize

# Settings shared by both 2D misalignment transforms below.
common_2d_options = dict(
    mode="bilinear",
    prob=1.0,
    shear_range=(0.1, 0.1),      # Shear range
    translate_range=(5, 5),      # Translation range in pixels
    padding_mode="border",
)

# Define the MONAI RandAffine transform
rand_affine_transform = RandAffine(
    spatial_size=None,
    rotate_range=(0.2, 0.2),     # Rotation range in radians
    **common_2d_options,
)

# Define the MONAI Rand2DElastic transform
rand_2d_elastic_transform = Rand2DElastic(
    spacing=(30, 30),            # Distance between control points
    magnitude_range=(0.1, 0.1),  # Deformation magnitude
    rotate_range=(0.1, 0.1),
    **common_2d_options,
)

# Pipeline consisting of a single ToTensor step.
transform = Compose([ToTensor()])




In [None]:
# Create a dataset with persistent storage, filled with 1000 random
# 256x256 image / binary-label pairs as placeholder data.
dataset = PersistentDataset(
    data=[{"image": torch.randn(256, 256), "label": torch.randint(0, 2, (256, 256))} for _ in range(1000)],
    # PersistentDataset takes a transform argument; None means the items are
    # used as they are.
    transform=None,
    cache_dir="./cache",  # Directory to store the dataset cache
    # The former `refresh=False` argument was removed: PersistentDataset has no
    # such parameter and rejected it with a TypeError.
)

In [None]:
# data_new_1 : t1, t2_align, t2_misalign, mutual_info

# data_new_2 : t1, t2_align, t2_misalign, mutual_info

# data_new_3 : t1, t2_align, t2_misalign, mutual_info

# data_new_4 : t1, t2_align, t2_misalign, mutual_info

# data_new_5 : t1, t2_align, t2_misalign, mutual_info

In [None]:
from monai.visualize import matshow3d, blend_images

# Take indices 20..23 of the third axis (index 0 of the fourth) from both
# arrays, join them along axis 1 so each frame shows the data_x slice next to
# the data_y slice, and display one frame per index.
t1_frames = f['data_x'][:,:,20:24,0]
t2_frames = f['data_y'][:,:,20:24,0]
side_by_side = np.concatenate((t1_frames, t2_frames), axis=1)

matshow3d(side_by_side, frame_dim=-1, show=True, cmap='gray')

In [None]:
from monai.networks.nets import Generator, Discriminator


In [None]:
# Base Network
from monai.networks.nets import AttentionUnet

ngf = 24
# Four levels of feature maps, doubling at every level: ngf, 2*ngf, 4*ngf, 8*ngf.
channel_widths = tuple(ngf * 2 ** level for level in range(4))
net = AttentionUnet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=channel_widths,
    strides=[1, 1, 1, 1],
)

# Push one random single-channel 256x256 image through the network and print
# the output shape as a smoke test.
inp = torch.randn(1, 1, 256, 256)
out = net(inp)
print(out.shape)


In [None]:
import pytorch_lightning as pl


In [None]:
import torch
import torch.nn as nn
from monai.networks.nets import AttentionUnet, Discriminator

class CycleGAN(nn.Module):
    """
    Two generators and two discriminators wired together for cycle-consistent
    translation between an image domain A and an image domain B.

    Args:
        in_channels (int): Number of channels of domain-A images. Default is 1.
        out_channels (int): Number of channels of domain-B images. Default is 1.
        ngf (int): Base generator width; the four generator levels use
            ngf, 2*ngf, 4*ngf and 8*ngf channels. Default is 32.
        ndf (int): Base discriminator width; the four discriminator levels use
            ndf, 2*ndf, 4*ndf and 8*ndf channels. Default is 64.
        image_size (tuple of int): Spatial size (H, W) of the images given to the
            discriminators. MONAI's ``Discriminator`` needs the full input shape
            up front. Default is (256, 256).
    """

    def __init__(self, in_channels=1, out_channels=1, ngf=32, ndf=64, image_size=(256, 256)):
        super().__init__()

        # Generators. AttentionUnet needs one stride per transition between
        # channel levels (len(channels) - 1); the single-element strides=[1]
        # used before was too short. Stride 1 is kept at every level.
        gen_channels = (ngf, ngf * 2, ngf * 4, ngf * 8)
        gen_strides = (1,) * (len(gen_channels) - 1)
        self.gen_AtoB = AttentionUnet(
            spatial_dims=2,
            in_channels=in_channels,
            out_channels=out_channels,
            channels=gen_channels,
            strides=gen_strides,
        )
        # B -> A consumes domain-B images and must emit domain-A images, so the
        # channel counts are swapped relative to gen_AtoB.
        self.gen_BtoA = AttentionUnet(
            spatial_dims=2,
            in_channels=out_channels,
            out_channels=in_channels,
            channels=gen_channels,
            strides=gen_strides,
        )

        # Discriminators. MONAI's Discriminator takes the full input shape
        # (channels, H, W) and one stride per channel level.
        # NOTE(review): stride 2 per level is a choice made here (the original
        # call supplied no strides at all) - adjust if another depth is wanted.
        dis_channels = (ndf, ndf * 2, ndf * 4, ndf * 8)
        dis_strides = (2,) * len(dis_channels)
        self.dis_A = Discriminator(
            in_shape=(in_channels, *image_size),
            channels=dis_channels,
            strides=dis_strides,
        )
        self.dis_B = Discriminator(
            in_shape=(out_channels, *image_size),
            channels=dis_channels,
            strides=dis_strides,
        )

    def forward(self, real_A, real_B):
        """
        Run both translation cycles and score real and generated images.

        Args:
            real_A: Batch of domain-A images.
            real_B: Batch of domain-B images.

        Returns:
            tuple: ``(fake_A, fake_B, cycle_A, cycle_B, real_A_dis_out,
            real_B_dis_out, fake_A_dis_out, fake_B_dis_out)``.
        """
        # A -> B -> A
        fake_B = self.gen_AtoB(real_A)
        cycle_A = self.gen_BtoA(fake_B)

        # B -> A -> B
        fake_A = self.gen_BtoA(real_B)
        cycle_B = self.gen_AtoB(fake_A)

        # Discriminator outputs
        real_A_dis_out = self.dis_A(real_A)
        fake_A_dis_out = self.dis_A(fake_A)

        real_B_dis_out = self.dis_B(real_B)
        fake_B_dis_out = self.dis_B(fake_B)

        return fake_A, fake_B, cycle_A, cycle_B, real_A_dis_out, real_B_dis_out, fake_A_dis_out, fake_B_dis_out


In [None]:
# Instantiate CycleGAN with its default constructor arguments. This rebinds
# `net`, which an earlier cell bound to a standalone AttentionUnet.
net = CycleGAN()