In [4]:
import os
import sys
sys.path.append('.')
import time
import shutil
import wandb
import logging
import imgaug as ia
import numpy as np
from PIL import Image
from imgaug import augmenters as iaa
from copy import deepcopy

import torch
import torchvision
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Subset
from torch.nn import CrossEntropyLoss

from util import util
from UNet3D.config import (
    TRAINING_EPOCH, NUM_CLASSES, IN_CHANNELS, BCE_WEIGHTS, BACKGROUND_AS_CLASS, TRAIN_CUDA
)
from UNet3D.unet3d import UNet3D
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from models.networks import arch_parameters
from transforms import fake_transform
from util.util import zero_division

from betty.engine import Engine
from betty.configs import Config, EngineConfig
from betty.problems import ImplicitProblem

In [13]:
from data.base_dataset import BaseDataset, get_params_3d, get_transform_torchio
import os
import torchio
import torch
import pandas as pd
from util.util import error, nifti_to_np


class VestibularDataset(BaseDataset):
    """3D MRI dataset indexed by a CSV of image paths, mask paths and class strings.

    The CSV given by ``--csv_path`` must contain the columns ``Image``, ``Mask``
    and ``Class``. Class strings are converted to integers through
    ``class_mapping`` ('decreasing' -> 0, 'increasing' -> 1).
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add the required ``--csv_path`` option and default to 1 input / 1 output channel.

        ``is_train`` is part of the options-hook interface and is not used here.
        """
        parser.add_argument('--csv_path', type=str, required=True, help='Path to the CSV file containing image, mask, and class information.')
        # Single-channel MRI volumes and single-channel masks.
        parser.set_defaults(input_nc=1, output_nc=1)
        return parser

    def __init__(self, opt):
        """Read the CSV index and cache its three columns as plain Python lists."""
        BaseDataset.__init__(self, opt)
        self.csv_path = opt.csv_path
        self.data_frame = pd.read_csv(self.csv_path)

        frame = self.data_frame
        self.image_paths = frame['Image'].tolist()
        self.mask_paths = frame['Mask'].tolist()
        self.classes = frame['Class'].tolist()

        # 'Class' column string -> integer target.
        self.class_mapping = {'decreasing': 0, 'increasing': 1}

        # Overwritten by every __getitem__ call with the shape/affine of the most
        # recently loaded image volume.
        # NOTE(review): with a multi-worker DataLoader these are set on the worker
        # copies of the dataset, so the main process may never see them.
        self.original_shape = None
        self.affine = None

    def __getitem__(self, index):
        """Load, transform and package one sample.

        Returns a dict with keys 'A' (transformed image data), 'mask' (binarised
        float mask data), 'class_label' (long scalar tensor), 'A_paths' and
        'mask_paths' (the source file paths).

        Raises:
            KeyError: if the row's class string is not a key of ``class_mapping``.
        """
        img_file = self.image_paths[index]
        seg_file = self.mask_paths[index]
        label_id = self.class_mapping[self.classes[index]]

        # Whole volumes (sliced=False); only the image's affine is kept and reused for the mask.
        volume, vol_affine = nifti_to_np(img_file, sliced=False, chosen_slice=None)
        seg_volume, _ = nifti_to_np(seg_file, sliced=False, chosen_slice=None)

        self.original_shape = volume.shape
        self.affine = vol_affine

        # Prepend a channel axis for TorchIO.
        # Presumably nifti_to_np returns a 3D (H, W, D) array; TODO confirm.
        scalar_img = torchio.Image(tensor=volume[None, ...], affine=vol_affine, type=torchio.INTENSITY)
        label_map = torchio.LabelMap(tensor=seg_volume[None, ...], affine=vol_affine)

        # One parameter set (derived from the image shape) builds the transform
        # that is then applied to the image and to the mask.
        # NOTE(review): the transform is called twice, once per image. If it contains
        # random TorchIO transforms that re-sample on every call, image and mask could
        # receive different augmentations. Confirm get_transform_torchio is
        # deterministic for given params, or apply it once to a torchio.Subject.
        pipeline = get_transform_torchio(self.opt, get_params_3d(self.opt, scalar_img.shape))
        img_out = pipeline(scalar_img)
        seg_out = pipeline(label_map)

        # Binarise: any positive voxel becomes 1.0, everything else 0.0.
        seg_out.data = (seg_out.data > 0).float()

        return dict(
            A=img_out.data,
            mask=seg_out.data,
            class_label=torch.tensor(label_id, dtype=torch.long),
            A_paths=img_file,
            mask_paths=seg_file,
        )

    def __len__(self):
        """Number of samples, i.e. rows in the CSV."""
        return len(self.image_paths)


In [15]:
import nibabel as nib
import torch
import numpy as np
import os

# --- RandomFlip: probabilistic flip transform for image/label sample dicts ---
class RandomFlip:
    """Randomly flip an image/label pair along one tensor dimension.

    With probability ``prob``, both ``sample['image']`` and ``sample['label']``
    are flipped along dimension ``spatial_axis`` via ``torch.flip``. A single
    random draw is shared by both tensors so they stay aligned. The index is a
    raw tensor dimension, so for channel-first tensors axis 0 is the channel
    axis. Any other keys of ``sample`` are passed through unchanged.

    Args:
        prob: flip probability. Values >= 1 always flip and values <= 0 never flip,
            because ``torch.rand`` samples from [0, 1).
        spatial_axis: tensor dimension to flip along.
    """

    def __init__(self, prob=0.5, spatial_axis=0):
        self.prob = prob
        self.axis = spatial_axis

    def __call__(self, sample):
        """Return a new dict with (possibly) flipped 'image' and 'label'.

        ``sample`` itself is not mutated. Extra keys (e.g. metadata) are copied
        into the result instead of being dropped.
        """
        image, label = sample['image'], sample['label']
        # One draw decides for both tensors, keeping them spatially consistent.
        if torch.rand(1).item() < self.prob:
            image = torch.flip(image, dims=[self.axis])
            label = torch.flip(label, dims=[self.axis])
        return {**sample, 'image': image, 'label': label}


# --- Load NIfTI file ---
def load_nifti(filepath):
    """Read a NIfTI file and return ``(voxel_data, affine)``.

    ``voxel_data`` is produced by nibabel's ``get_fdata()``, i.e. a floating-point
    NumPy array (float64 by default) whatever the on-disk dtype. ``affine`` is the
    image's voxel-to-world matrix.
    """
    img = nib.load(filepath)
    return img.get_fdata(), img.affine


# --- Save NIfTI file ---
def save_nifti(data, affine, out_path):
    """Write ``data`` with ``affine`` to ``out_path`` as a NIfTI-1 image, then print the path.

    nibabel builds a fresh header from ``data`` and ``affine``; no header from a
    source file is carried over.
    """
    nib.save(nib.Nifti1Image(data, affine), out_path)
    print(f"Saved: {out_path}")


# --- Main pipeline ---
def apply_flip_and_save(image_path, label_path, axis_list=(0, 1, 2), prob=1.0, output_dir='flipped_outputs'):
    """Flip an image/label NIfTI pair along each requested axis and save each result.

    Every axis is flipped independently, starting from the original arrays, so
    flips are not cumulative. For each axis two files are written to
    ``output_dir``:
        image_flip_axis{axis}.nii.gz and label_flip_axis{axis}.nii.gz.
    Both reuse the original image's affine, so a flipped volume appears mirrored
    in world space when viewed (e.g. in 3D Slicer).

    Args:
        image_path: path to the intensity NIfTI volume.
        label_path: path to the label/segmentation NIfTI volume.
        axis_list: iterable of array axes to flip along; one output pair per axis.
            The default is an immutable tuple to avoid the shared-mutable-default
            pitfall.
        prob: flip probability forwarded to ``RandomFlip``. With ``prob < 1`` a file
            named ``*_flip_axis{axis}`` may contain an unflipped volume.
        output_dir: destination directory; created if it does not exist.

    Note:
        Both volumes are converted to float32 before saving, so the written label
        file is float32 rather than its original on-disk dtype.
    """
    os.makedirs(output_dir, exist_ok=True)

    image_np, affine = load_nifti(image_path)
    label_np, _ = load_nifti(label_path)

    image_tensor = torch.tensor(image_np, dtype=torch.float32)
    label_tensor = torch.tensor(label_np, dtype=torch.float32)

    for axis in axis_list:
        flipper = RandomFlip(prob=prob, spatial_axis=axis)
        flipped = flipper({'image': image_tensor, 'label': label_tensor})

        flipped_img = flipped['image'].numpy()
        flipped_lbl = flipped['label'].numpy()

        # Save outputs for visualization in 3D Slicer.
        save_nifti(flipped_img, affine, os.path.join(output_dir, f'image_flip_axis{axis}.nii.gz'))
        save_nifti(flipped_lbl, affine, os.path.join(output_dir, f'label_flip_axis{axis}.nii.gz'))


# --- Example Usage ---
# Input files are looked up in the directory given by the VS_DATA_DIR environment
# variable (default: ~/Downloads) instead of a hard-coded, user-specific absolute
# path. Change the file names below to point at your own NIfTI files.
DATA_DIR = os.environ.get('VS_DATA_DIR', os.path.join(os.path.expanduser('~'), 'Downloads'))
image_path = os.path.join(DATA_DIR, 'vs_gk_0000.nii')
label_path = os.path.join(DATA_DIR, 'Struct_TUMOR.nii')

apply_flip_and_save(image_path, label_path, axis_list=[0, 1, 2], prob=1.0)  # prob=1.0 -> always flip


Saved: flipped_outputs\image_flip_axis0.nii.gz
Saved: flipped_outputs\label_flip_axis0.nii.gz
Saved: flipped_outputs\image_flip_axis1.nii.gz
Saved: flipped_outputs\label_flip_axis1.nii.gz
Saved: flipped_outputs\image_flip_axis2.nii.gz
Saved: flipped_outputs\label_flip_axis2.nii.gz
