# 3D U-Net for Kidney Tumor Segmentation

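This notebook implements a 3D U-Net model for kidney and kidney-tumor segmentation on the KiTS23 dataset. The pipeline runs in five steps:

1. Download the imaging volumes.
2. Delete the per-instance annotation files.
3. Resample every volume to 128×128×128.
4. Remove segmentation label 3, so that only background, kidney and tumor remain.
5. Train and evaluate the model.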

In [None]:
! pip install segmentation_models_3D
! pip install tensorflow keras

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
from collections import Counter
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split

import nibabel as nib
import segmentation_models_3D as sm

# Download Dataset

In [None]:
! git clone https://github.com/neheller/kits23.git /content/kits23/

In [None]:
import sys
from tqdm import tqdm
from pathlib import Path
import urllib.request
import shutil
from time import sleep
import concurrent.futures


# KiTS23 training case numbers to download: 0-299 and 400-588 (300-399 are not included).
TRAINING_CASE_NUMBERS = list(range(300)) + list(range(400, 589))

# Root folder of the cloned kits23 repo's dataset; each case goes to DST_PTH/<case_id>/imaging.nii.gz.
DST_PTH = Path("/content/kits23/dataset/")


def get_destination(case_id: str, create: bool = False):
    """Return the local imaging path ``DST_PTH/<case_id>/imaging.nii.gz``.

    Args:
        case_id: Case folder name, e.g. ``"case_00012"``.
        create: When True, create the case folder first. Creation is non-recursive, so
            DST_PTH itself must already exist.

    Returns:
        pathlib.Path: Where the case's ``imaging.nii.gz`` lives (it may not exist yet).
    """
    case_folder = DST_PTH / case_id
    if create:
        case_folder.mkdir(exist_ok=True)
    return case_folder / "imaging.nii.gz"


def cleanup(tmp_pth: Path, e: Exception):
    """Delete a partial download, then abort.

    Args:
        tmp_pth: Temporary ``.partial.*`` file. It is removed if it exists.
        e: The failure to propagate. ``None`` means "interrupted by the user": a message is
            printed and the interpreter exits via ``sys.exit()``. Otherwise ``e`` is re-raised.
    """
    if tmp_pth.exists():
        tmp_pth.unlink()

    if e is not None:
        raise e
    print("\nInterrupted.\n")
    sys.exit()


def download_case(case_num: int, retry=True):
    """Download one KiTS23 imaging volume into ``DST_PTH/case_XXXXX/imaging.nii.gz``.

    The transfer is written to a hidden ``.partial.imaging.nii.gz`` sibling first. It is moved
    into place only after ``urlretrieve`` returns, so a failed download never leaves a
    truncated ``imaging.nii.gz`` behind.

    Args:
        case_num: Numeric KiTS23 case id.
        retry: On the first failure, wait 5 s and try exactly once more. When False, a
            failure deletes the partial file and re-raises through ``cleanup``.

    Returns:
        int: ``case_num``, so a ``map`` over cases can be used for progress tracking.
    """
    remote_name = f"master_{case_num:05d}.nii.gz"
    url = f"https://kits19.sfo2.digitaloceanspaces.com/{remote_name}"
    final_path = get_destination(f"case_{case_num:05d}", create=True)
    partial_path = final_path.parent / f".partial.{final_path.name}"

    try:
        urllib.request.urlretrieve(url, str(partial_path))
        shutil.move(str(partial_path), str(final_path))
    except Exception as err:
        if not retry:
            cleanup(partial_path, err)
        else:
            print(f"\nFailed to download case_{case_num:05d}. Retrying...")
            sleep(5)
            return download_case(case_num, retry=False)
    return case_num  # Return case number for progress tracking


def download_dataset(max_workers=8):
    """Download every case in TRAINING_CASE_NUMBERS whose imaging file is not on disk yet.

    Already-present ``imaging.nii.gz`` files are skipped, so re-running the cell is cheap.

    Args:
        max_workers (int): Number of concurrent download threads. Tune this to your network
            and system (the previous hard-coded value, 8, is the default).

    Raises:
        Exception: The first download error that ``download_case`` re-raises after its retry.
    """
    # parents=True so this also works when the kits23 checkout has not created the parent folders.
    DST_PTH.mkdir(parents=True, exist_ok=True)

    # Determine which cases still need to be downloaded
    left_to_download = [
        case_num
        for case_num in TRAINING_CASE_NUMBERS
        if not get_destination(f"case_{case_num:05d}").exists()
    ]

    print(f"\nFound {len(left_to_download)} cases to download\n")
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Consuming the map drives tqdm and re-raises the first worker exception, if any.
        for _ in tqdm(executor.map(download_case, left_to_download), total=len(left_to_download)):
            pass
    print("\nAll cases downloaded.")

In [None]:
# Fetch every training imaging volume not yet under DST_PTH (existing files are skipped).
download_dataset()

# Test Dataset

# Preprocess Dataset



## Remove Instances

In [None]:
import os
import glob
from tqdm import tqdm

# Define the dataset directory
kits23_path = "/content/kits23/dataset"  # Change this path if needed

# Case directories, sorted for a deterministic processing order
case_dirs = sorted([d for d in os.listdir(kits23_path) if os.path.isdir(os.path.join(kits23_path, d))])

# Scan each case's "instances" folder once. The same lists drive both the progress total
# and the deletion loop, so the directory tree is not globbed twice.
instance_files_by_case = {
    case_dir: glob.glob(os.path.join(kits23_path, case_dir, "instances", "*.nii.gz"))
    for case_dir in case_dirs
}
total_files = sum(len(files) for files in instance_files_by_case.values())

failed_deletions = 0
with tqdm(total=total_files, desc="Deleting instance files") as pbar:
    for case_dir, instance_files in instance_files_by_case.items():
        for file in instance_files:
            try:
                os.remove(file)
            except Exception as e:
                failed_deletions += 1
                print(f"Error deleting {file}: {e}")
            finally:
                # Advance on failures too, so the bar always reaches its total.
                pbar.update(1)
                pbar.set_postfix(case=case_dir, file=os.path.basename(file))

        # Remove the 'instances' folder if it is now empty
        instances_folder = os.path.join(kits23_path, case_dir, "instances")
        if os.path.isdir(instances_folder) and not os.listdir(instances_folder):
            try:
                os.rmdir(instances_folder)
            except Exception as e:
                print(f"Error removing directory {instances_folder}: {e}")

if failed_deletions:
    print(f"Finished with {failed_deletions} instance file(s) that could not be deleted.")
else:
    print("Instance segmentation files deleted successfully.")

## Downsize and Decompress

In [None]:
import os
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom
import concurrent.futures
from tqdm import tqdm

def process_file(gz_path, target_shape=(96, 96, 96), image_order=1, label_order=0):
    """
    Resample one ``.nii.gz`` volume to ``target_shape`` and replace it with an uncompressed ``.nii``.

    Steps:
      - load the compressed file (nibabel decompresses it on the fly),
      - resample each spatial axis directly to ``target_shape``. Aspect ratio is NOT
        preserved; every axis gets its own zoom factor relative to the original shape,
      - save ``<name>.nii`` next to the original, with the affine and header voxel sizes
        rescaled to the new grid,
      - delete the original ``.nii.gz``.

    Label maps vs. intensity volumes:
      - Files whose basename contains "seg" or "label" (e.g. ``segmentation.nii.gz``) are
        label maps. They are resampled with ``label_order`` and stored as uint8.
      - All other files (e.g. ``imaging.nii.gz``) are intensity volumes. They are resampled
        with ``image_order`` and stored as float32, so negative CT values are preserved
        instead of being wrapped by a uint8 cast.

    Args:
        gz_path (str): Path to a file ending in ``.nii.gz``.
        target_shape (tuple[int, int, int]): Output spatial shape.
        image_order (int): Spline order used for intensity volumes (1 = trilinear).
        label_order (int): Spline order used for label maps (0 = nearest neighbour).

    Returns:
        str: Path of the written ``.nii`` file.
    """
    target_shape = tuple(int(t) for t in target_shape)

    # Load the compressed file (decompressed in memory)
    nifti_obj = nib.load(gz_path)
    vol_data = nifti_obj.get_fdata(dtype=np.float32)
    affine = nifti_obj.affine
    header = nifti_obj.header

    # Label maps need nearest-neighbour resampling and integer storage; images must stay floating point.
    basename = os.path.basename(gz_path).lower()
    is_segmentation = "seg" in basename or "label" in basename
    interp_order = label_order if is_segmentation else image_order

    pixdim = header.get_zooms()[:3]  # original voxel spacing
    orig_shape = vol_data.shape[:3]

    # One zoom factor per spatial axis, relative to the ORIGINAL shape. Any trailing
    # (non-spatial) axes are left unchanged.
    zoom_factors = [t / s for t, s in zip(target_shape, orig_shape)] + [1.0] * (vol_data.ndim - 3)
    vol_data_resized = zoom(vol_data, zoom_factors, order=interp_order, mode='nearest')

    # zoom rounds each output size, so it can be off by one voxel: crop / zero-pad to the exact shape.
    if vol_data_resized.shape[:3] != target_shape:
        final_vol = np.zeros(target_shape + vol_data_resized.shape[3:], dtype=vol_data_resized.dtype)
        slices = tuple(slice(0, min(t, s)) for t, s in zip(target_shape, vol_data_resized.shape[:3]))
        final_vol[slices] = vol_data_resized[slices]
        vol_data_resized = final_vol

    if is_segmentation:
        vol_data_resized = np.round(vol_data_resized).astype(np.uint8)
    else:
        vol_data_resized = vol_data_resized.astype(np.float32, copy=False)

    # Each new voxel spans orig/target old voxels along each axis. Scale the affine's columns
    # as well as the header zooms: the affine handed to Nifti1Image is what nibabel writes
    # to the sform/qform, so updating only the header zooms would save a stale spacing.
    # The origin stays at the first voxel (sub-voxel offset accepted).
    spacing_scale = [o / t for o, t in zip(orig_shape, target_shape)]
    new_affine = affine @ np.diag(spacing_scale + [1.0])
    new_header = header.copy()
    new_zooms = tuple(p * s for p, s in zip(pixdim, spacing_scale)) + tuple(header.get_zooms()[3:])
    new_header.set_zooms(new_zooms)
    new_header.set_data_dtype(vol_data_resized.dtype)

    # Save the downsized volume as an uncompressed .nii file (remove ".gz" extension)
    new_nii_path = gz_path[:-3]
    downsized_nifti = nib.Nifti1Image(vol_data_resized, new_affine, new_header)
    nib.save(downsized_nifti, new_nii_path)

    # Remove the original .nii.gz file
    os.remove(gz_path)

    return new_nii_path

def downsize_nii_gz_multithreaded(
    root_dir,
    target_shape=(128, 128, 128),
    image_order=1,
    label_order=0,
    max_workers=4
):
    """
    Find every ``.nii.gz`` file under ``root_dir`` and run ``process_file`` on each in parallel.

    Despite the name, work is spread over worker *processes* (ProcessPoolExecutor), not threads.
    Hidden files (e.g. ``.partial.imaging.nii.gz`` left by an interrupted download) are skipped.
    A failure is reported together with the failing path and does not stop the other files.

    Args:
        root_dir (str): Directory searched recursively.
        target_shape (tuple[int, int, int]): Output spatial shape passed to ``process_file``.
        image_order (int): Interpolation order for intensity volumes.
        label_order (int): Interpolation order for label maps.
        max_workers (int): Number of worker processes.

    Returns:
        list[str]: Paths of the ``.nii`` files that were written successfully.
    """
    # Gather all .nii.gz file paths from the directory tree (sorted for a deterministic order).
    gz_files = sorted(
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk(root_dir)
        for filename in filenames
        if filename.endswith(".nii.gz") and not filename.startswith(".")
    )

    print(f"Found {len(gz_files)} .nii.gz files to process.")

    written = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(process_file, gz_path, target_shape, image_order, label_order): gz_path
            for gz_path in gz_files
        }

        for future in tqdm(concurrent.futures.as_completed(futures), total=len(gz_files), desc="Processing files"):
            try:
                written.append(future.result())
            except Exception as exc:
                # Name the file: the exception text alone usually does not identify it.
                print(f"File processing generated an exception for {futures[future]}: {exc}")
    return written

In [None]:
kits23_root = "/content/kits23/dataset"  # adjust this path as needed
# Resample every remaining .nii.gz under kits23_root to 128x128x128, replacing each with an
# uncompressed .nii. Adjust max_workers based on your CPU cores.
downsize_nii_gz_multithreaded(kits23_root, target_shape=(128, 128, 128), max_workers=8)
print("All .nii.gz files processed: downsized and saved as .nii.")

Found 980 .nii.gz files to process.


Processing files: 100%|██████████| 980/980 [06:45<00:00,  2.42it/s]

All .nii.gz files processed: downsized and saved as .nii.





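## Delete Label 3

Zero out segmentation label 3. Every mask then contains only background (0), kidney (1) and tumor (2), which are the three classes the metrics below index.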

In [None]:
import os
import numpy as np
import nibabel as nib
from pathlib import Path
import concurrent.futures
from tqdm import tqdm

# Configuration
ROOT_DIR = "/content/kits23/dataset/"  # Searched recursively for segmentation.nii* files; change to your dataset path
DRY_RUN = False  # True: only report what would be deleted, without rewriting any file
MAX_WORKERS = 8  # Thread-pool size used by main()

def delete_last_label(file_path, dry_run=False, max_label_to_keep=2):
    """
    Zero out every label above ``max_label_to_keep`` in a 3D segmentation file (in place).

    With the default ``max_label_to_keep=2``, label 3 is removed and labels 0, 1 and 2 are kept.
    Only labels strictly above the threshold are touched, so a case without label 3
    (e.g. labels {0, 1, 2}) is left unchanged and re-running the step is harmless.
    Removing the *highest* label present would instead erase label 2 from such cases.

    Args:
        file_path (str | Path): Path to a ``segmentation.nii`` / ``segmentation.nii.gz`` file.
        dry_run (bool): If True, only report what would be removed; the file is not rewritten.
        max_label_to_keep (int): Highest label value that is preserved.

    Returns:
        dict: Always contains 'case_name' and 'filename'. On success it also has 'modified'.
        When modified, it adds 'last_label' (highest label removed) and 'voxels_deleted';
        otherwise it adds 'message'. On failure it contains 'error'; exceptions are caught
        and reported, never raised.
    """
    try:
        file_path = Path(file_path)
        case_name = file_path.parent.name

        # Load the segmentation file
        seg_img = nib.load(file_path)
        seg_data = seg_img.get_fdata()

        # get_fdata returns floats; round before comparing label values.
        labels = np.round(seg_data)
        to_delete = labels > max_label_to_keep
        voxels_to_delete = int(np.count_nonzero(to_delete))

        if voxels_to_delete == 0:
            return {
                'case_name': case_name,
                'filename': file_path.name,
                'modified': False,
                'message': f'No labels above {max_label_to_keep} present'
            }

        last_label = int(labels.max())  # highest label being removed

        # Create modified data by zeroing out the removed labels
        modified_data = seg_data.copy()
        modified_data[to_delete] = 0

        # Save if not dry run
        if not dry_run:
            modified_img = nib.Nifti1Image(modified_data, seg_img.affine, seg_img.header)
            nib.save(modified_img, str(file_path))

        return {
            'case_name': case_name,
            'filename': file_path.name,
            'modified': True,
            'last_label': last_label,
            'voxels_deleted': voxels_to_delete
        }

    except Exception as e:
        return {
            'case_name': file_path.parent.name if isinstance(file_path, Path) else os.path.basename(os.path.dirname(file_path)),
            'filename': file_path.name if isinstance(file_path, Path) else os.path.basename(file_path),
            'error': str(e)
        }

def main():
    """Apply delete_last_label to every ``segmentation.nii*`` under ROOT_DIR and print a report.

    Reads the module-level settings ROOT_DIR, DRY_RUN and MAX_WORKERS. Files are processed
    on a thread pool. Per-file errors are collected from the returned dicts and listed at
    the end instead of being raised.
    """
    seg_files = list(Path(ROOT_DIR).glob("**/segmentation.nii*"))
    if len(seg_files) == 0:
        print(f"No segmentation files found in {ROOT_DIR}")
        return

    print(f"Found {len(seg_files)} segmentation files to process")

    results = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        pending = [pool.submit(delete_last_label, seg_path, DRY_RUN) for seg_path in seg_files]
        for done in tqdm(concurrent.futures.as_completed(pending), total=len(pending), desc="Processing files"):
            results.append(done.result())

    modified = [r for r in results if r.get('modified', False)]
    errored = [r for r in results if 'error' in r]
    voxels_removed = sum(r.get('voxels_deleted', 0) for r in modified)

    print("\n=== SUMMARY ===")
    print(f"Total files processed: {len(results)}")
    print(f"Files modified (last label deleted): {len(modified)}")
    print(f"Total voxels deleted: {voxels_removed}")
    print(f"Files with errors: {len(errored)}")

    if errored:
        print("\nErrors:")
        for r in errored:
            print(f" - {r['case_name']}/{r['filename']}: {r['error']}")

    if DRY_RUN:
        print("\nThis was a DRY RUN. No files were actually modified.")
        print("Set DRY_RUN = False to save changes.")
        print("Set DRY_RUN = False to save changes.")

In [None]:
# Run the label clean-up. In a notebook kernel __name__ is "__main__", so this cell always executes main().
if __name__ == "__main__":
    main()

## Test Dataset

In [None]:
import os

def check_dataset_files(dataset_path):
    """
    Audit a dataset folder whose immediate sub-directories are case folders.

    For each case, count ``.nii`` and ``.nii.gz`` files, record their base names (text
    before the first dot), and note whether an imaging / segmentation file (either
    extension) is present. A short summary is printed.

    Args:
        dataset_path (str): Path to the dataset directory containing case folders.

    Returns:
        dict: With these keys:
            - total_cases
            - total_nii_files
            - total_nii_gz_files
            - common_filenames: sorted list
            - missing_imaging, missing_segmentation: case names in ``os.listdir`` order
            - abnormal_cases: case -> {'nii', 'nii_gz', 'total'}, for cases whose total != 2

    Raises:
        FileNotFoundError: If ``dataset_path`` does not exist.
    """
    if not os.path.exists(dataset_path):
        raise FileNotFoundError(f"Dataset path not found: {dataset_path}")

    case_folders = [
        name for name in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, name))
    ]
    print(f"Found {len(case_folders)} case folders to check")

    imaging_names = {"imaging.nii", "imaging.nii.gz"}
    segmentation_names = {"segmentation.nii", "segmentation.nii.gz"}
    per_case = {}
    base_names = set()
    missing_imaging = []
    missing_segmentation = []

    for case_folder in case_folders:
        entries = os.listdir(os.path.join(dataset_path, case_folder))
        n_nii = sum(1 for entry in entries if entry.endswith('.nii'))
        n_gz = sum(1 for entry in entries if entry.endswith('.nii.gz'))
        base_names.update(
            entry.split('.')[0]
            for entry in entries
            if entry.endswith('.nii') or entry.endswith('.nii.gz')
        )
        per_case[case_folder] = {'nii': n_nii, 'nii_gz': n_gz, 'total': n_nii + n_gz}

        if imaging_names.isdisjoint(entries):
            missing_imaging.append(case_folder)
        if segmentation_names.isdisjoint(entries):
            missing_segmentation.append(case_folder)

    total_nii = sum(counts['nii'] for counts in per_case.values())
    total_gz = sum(counts['nii_gz'] for counts in per_case.values())

    print(f"Total .nii files found: {total_nii}")
    print(f"Total .nii.gz files found: {total_gz}")
    print(f"Combined total: {total_nii + total_gz}")
    print(f"Expected total (if all cases have 2 files): {len(case_folders) * 2}")
    print(f"Common filenames in dataset: {sorted(base_names)}")
    print(f"Cases missing imaging files: {len(missing_imaging)}")
    print(f"Cases missing segmentation files: {len(missing_segmentation)}")

    abnormal_cases = {case: counts for case, counts in per_case.items() if counts['total'] != 2}
    print(f"Cases with != 2 files: {len(abnormal_cases)}")

    return {
        "total_cases": len(case_folders),
        "total_nii_files": total_nii,
        "total_nii_gz_files": total_gz,
        "common_filenames": sorted(base_names),
        "missing_imaging": missing_imaging,
        "missing_segmentation": missing_segmentation,
        "abnormal_cases": abnormal_cases
    }

In [None]:
# Audit the processed dataset, then list every case folder that does not hold exactly two NIfTI files.
results = check_dataset_files('/content/kits23/dataset')

print("\nCases with abnormal file counts:")
for case, info in sorted(results["abnormal_cases"].items()):
  print(f" - {case}: {info}")

# Dataloader

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
import random
from tqdm.notebook import tqdm
import nibabel as nib


class KiTS23Dataset:
    """
    Loads KiTS23 cases from disk and serves (image, one-hot mask) batches for Keras training.

    Each case is a sub-directory of ``dataset_path`` that contains ``imaging.nii[.gz]`` and
    ``segmentation.nii[.gz]``. On load:
      - images are resized to ``target_shape``, CT-windowed to [0, 1] and repeated to 3 channels;
      - masks are resized with nearest-neighbour interpolation and one-hot encoded over the
        labels found while scanning the first few cases.
    """

    def __init__(self, dataset_path, target_shape=(96, 96, 96), use_cache=True, cache_limit=50):
        """
        Initialize the KiTS23 dataset loader

        Args:
            dataset_path (str): Path to the KiTS23 dataset
            target_shape (tuple): Spatial shape (H, W, D) every volume is resized to on load
            use_cache (bool): Whether to cache preprocessed cases in memory
            cache_limit (int): Maximum number of cases to cache in memory
        """
        self.dataset_path = dataset_path
        self.target_shape = tuple(target_shape)
        self.use_cache = use_cache
        self.cache_limit = cache_limit
        self.cache = {}

        # Sorted so that the case order, and therefore the train/val/test split produced with
        # a fixed random_state, does not depend on the file system's directory-listing order.
        self.cases = [
            item for item in sorted(os.listdir(dataset_path))
            if os.path.isdir(os.path.join(dataset_path, item))
            and self._get_imaging_path(item)
            and self._get_segmentation_path(item)
        ]

        print(f"Found {len(self.cases)} valid cases with both imaging and segmentation data")

        # Get label statistics
        self.all_labels = self._get_all_labels()
        print(f"Found {len(self.all_labels)} unique labels: {sorted(self.all_labels)}")

    def _get_all_labels(self, max_cases=10):
        """
        Collect the integer label values present in the first ``max_cases`` segmentations.

        NOTE: a label that only occurs in later cases is not seen and gets no one-hot channel.
        """
        all_labels = set()
        for case in tqdm(self.cases[:max_cases], desc="Analyzing labels"):
            seg_file = self._get_segmentation_path(case)
            if not seg_file:
                continue
            try:
                seg_data = self._load_nifti(seg_file)
                # Stored as plain Python ints so the printed label list is readable.
                all_labels.update(int(v) for v in np.unique(np.round(seg_data).astype(int)))
            except Exception as e:
                print(f"Error loading {seg_file}: {e}")
        return all_labels

    def _get_imaging_path(self, case_id):
        """Return the case's ``imaging.nii.gz`` or ``imaging.nii`` path (in that preference), or None."""
        for ext in [".nii.gz", ".nii"]:
            path = os.path.join(self.dataset_path, case_id, f"imaging{ext}")
            if os.path.exists(path):
                return path
        return None

    def _get_segmentation_path(self, case_id):
        """Return the case's ``segmentation.nii.gz`` or ``segmentation.nii`` path, or None."""
        for ext in [".nii.gz", ".nii"]:
            path = os.path.join(self.dataset_path, case_id, f"segmentation{ext}")
            if os.path.exists(path):
                return path
        return None

    def _load_nifti(self, file_path):
        """Load a NIfTI file and return its data as a float array (``get_fdata``)."""
        img = nib.load(file_path)
        return img.get_fdata()

    def _resize_volume(self, volume, target_shape=None, order=0):
        """
        Resize the first three (spatial) axes of ``volume`` to ``target_shape``.

        Trailing axes (e.g. image or one-hot channels) are kept as they are.

        Args:
            volume (np.ndarray): Array of rank >= 3.
            target_shape (tuple | None): Spatial output shape; defaults to ``self.target_shape``.
            order (int): Spline order. Use 0 (nearest) for masks and 1 (linear) for images.

        Returns:
            np.ndarray: Array whose spatial shape is exactly ``target_shape``.
        """
        from scipy.ndimage import zoom

        if target_shape is None:
            target_shape = self.target_shape
        spatial = tuple(int(t) for t in target_shape[:3])
        volume = np.asarray(volume)
        if volume.shape[:3] == spatial:
            return volume

        # zoom needs one factor per axis: resize the spatial axes, keep the rest (factor 1).
        factors = [t / s for t, s in zip(spatial, volume.shape[:3])] + [1.0] * (volume.ndim - 3)
        resized = zoom(volume, factors, order=order, mode='nearest')

        # zoom rounds the output size; crop / zero-pad any off-by-one axis to the exact shape.
        if resized.shape[:3] != spatial:
            fixed = np.zeros(spatial + resized.shape[3:], dtype=resized.dtype)
            overlap = tuple(slice(0, min(t, s)) for t, s in zip(spatial, resized.shape[:3]))
            fixed[overlap] = resized[overlap]
            resized = fixed
        return resized

    def _preprocess_volume(self, volume, is_mask=False):
        """
        Resize a raw volume to ``target_shape`` and convert it to network input / target format.

        Masks: rounded to uint8, resized with nearest neighbour, then one-hot encoded over
        ``sorted(self.all_labels)``. Result shape is (H, W, D, num_labels), dtype float32.

        Images: resized linearly, clipped to the window centre 30 / width 400 (i.e. [-170, 230])
        and scaled to [0, 1]. A 3-D result is repeated to 3 channels, giving (H, W, D, 3) float32.
        """
        spatial = tuple(self.target_shape[:3])

        if is_mask:
            volume = np.round(np.asarray(volume)).astype(np.uint8)
            volume = self._resize_volume(volume, spatial, order=0)
            labels = sorted(self.all_labels)

            # Convert to one-hot encoding
            mask_classes = np.zeros(volume.shape[:3] + (len(labels),), dtype=np.float32)
            for i, label in enumerate(labels):
                mask_classes[..., i] = (volume == label)
            return mask_classes

        volume = np.asarray(volume, dtype=np.float32)
        volume = self._resize_volume(volume, spatial, order=1)

        # Window the intensities, then scale to [0, 1]
        window_center = 30
        window_width = 400
        window_min = window_center - window_width // 2
        window_max = window_center + window_width // 2
        volume = np.clip(volume, window_min, window_max)
        volume = (volume - window_min) / (window_max - window_min)

        # The backbone expects 3 input channels: repeat the single intensity channel.
        if volume.ndim == 3:
            volume = np.repeat(volume[..., np.newaxis], 3, axis=-1)

        return volume.astype(np.float32)

    def load_case(self, case_id):
        """
        Load imaging and segmentation data for a specific case

        Args:
            case_id (str): Case identifier

        Returns:
            tuple: (imaging_data, segmentation_data) after ``_preprocess_volume``. When caching
            is on, the same arrays are returned for repeated calls (not copies).

        Raises:
            ValueError: If either file is missing for the case.
        """
        if self.use_cache and case_id in self.cache:
            return self.cache[case_id]

        imaging_path = self._get_imaging_path(case_id)
        segmentation_path = self._get_segmentation_path(case_id)

        if not imaging_path or not segmentation_path:
            raise ValueError(f"Could not find imaging or segmentation files for case {case_id}")

        imaging_processed = self._preprocess_volume(self._load_nifti(imaging_path))
        segmentation_processed = self._preprocess_volume(self._load_nifti(segmentation_path), is_mask=True)

        result = (imaging_processed, segmentation_processed)

        if self.use_cache:
            # Evict a random cached case once the limit is reached
            if len(self.cache) >= self.cache_limit and self.cache:
                remove_key = random.choice(list(self.cache.keys()))
                del self.cache[remove_key]

            self.cache[case_id] = result

        return result

    def _augment_pair(self, image, mask):
        """Augmentation hook: currently returns ``image`` and ``mask`` unchanged."""
        return image, mask

    def batch_generator(self, case_ids, batch_size, preprocess_fn=None, augment=False):
        """
        Endlessly yield ``(images, masks)`` batches, reshuffling the cases before each pass.

        The last batch of a pass may be smaller than ``batch_size``. Cases that fail to load
        are reported and skipped. The caller's ``case_ids`` list is not modified. An empty
        ``case_ids`` ends the generator immediately instead of spinning forever.
        """
        # Work on a copy so shuffling does not reorder the caller's list.
        case_ids = list(case_ids)
        if not case_ids:
            return

        while True:
            random.shuffle(case_ids)

            for start in range(0, len(case_ids), batch_size):
                images, masks = [], []

                for case_id in case_ids[start:start + batch_size]:
                    try:
                        image, mask = self.load_case(case_id)
                        if augment:
                            image, mask = self._augment_pair(image, mask)
                    except Exception as e:
                        print(f"Error loading case {case_id}: {e}")
                        continue
                    images.append(image)
                    masks.append(mask)

                if not images:
                    continue

                # np.array needs identical shapes. Compare only the spatial axes (images and
                # masks have different channel counts) and resize stragglers to the first sample.
                ref_spatial = images[0].shape[:3]
                for j in range(len(images)):
                    if images[j].shape[:3] != ref_spatial:
                        images[j] = self._resize_volume(images[j], ref_spatial, order=1)
                    if masks[j].shape[:3] != ref_spatial:
                        # order=0 keeps one-hot values intact (no interpolation)
                        masks[j] = self._resize_volume(masks[j], ref_spatial, order=0)

                images = np.array(images)
                masks = np.array(masks)

                if preprocess_fn:
                    images = preprocess_fn(images)

                yield images, masks

    def create_tf_datasets(self, batch_size=2, validation_split=0.2, test_split=0.1, preprocess_fn=None):
        """
        Split ``self.cases`` into train / validation / test and build a batch generator for each.

        Args:
            batch_size (int): Cases per batch.
            validation_split (float): Fraction of ALL cases used for validation.
            test_split (float | None): Fraction of all cases held out for testing; None or 0 disables it.
            preprocess_fn (callable | None): Applied to each image batch.

        Returns:
            tuple: (train_gen, val_gen, test_gen or None, steps_per_epoch, validation_steps, test_steps).
            Step counts are ceil(len(cases) / batch_size), so one epoch of steps is exactly one
            pass over its split (the generators also yield the final partial batch).
        """
        if not test_split:
            # No test split, only train/validation
            train_cases, val_cases = train_test_split(self.cases, test_size=validation_split, random_state=42)
            test_cases = []
        else:
            train_cases, test_cases = train_test_split(self.cases, test_size=test_split, random_state=42)
            # Rescale so validation_split stays a fraction of ALL cases.
            train_cases, val_cases = train_test_split(
                train_cases, test_size=validation_split / (1 - test_split), random_state=42
            )

        print(f"Train: {len(train_cases)} cases, Validation: {len(val_cases)} cases, Test: {len(test_cases)} cases")

        train_gen = self.batch_generator(train_cases, batch_size, preprocess_fn, augment=False)
        val_gen = self.batch_generator(val_cases, batch_size, preprocess_fn, augment=False)
        test_gen = self.batch_generator(test_cases, batch_size, preprocess_fn, augment=False) if test_cases else None

        def _steps(cases):
            # Ceiling division: count the trailing partial batch as a step.
            return max(1, -(-len(cases) // batch_size))

        steps_per_epoch = _steps(train_cases)
        validation_steps = _steps(val_cases)
        test_steps = _steps(test_cases) if test_cases else 0

        return train_gen, val_gen, test_gen, steps_per_epoch, validation_steps, test_steps

# Model Training

In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import segmentation_models_3D as sm
from tensorflow.keras.optimizers import AdamW
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, CSVLogger, EarlyStopping
from tensorflow.keras import backend as K
from tensorflow.keras import backend as K

## Evaluation Metrics (Dice, IoU, Volumetric Similarity, Hausdorff Proxy)

In [None]:
# Soft Dice over the foreground channels (background excluded)
def kidney_tumor_dice(y_true, y_pred):
    """Soft Dice on channels ``1:`` (channel 0 is assumed to be background).

    Sums are taken over axes 1-3 (presumably the spatial axes of a (batch, H, W, D, C)
    tensor). Dice is computed per sample and per foreground channel, then averaged over both.
    The 1e-6 term keeps the ratio defined (and equal to 1) when a channel is empty in both
    the truth and the prediction.
    """
    eps = 1e-6
    spatial_axes = [1, 2, 3]
    fg_true = y_true[..., 1:]
    fg_pred = y_pred[..., 1:]

    overlap = tf.reduce_sum(fg_true * fg_pred, axis=spatial_axes)
    total = tf.reduce_sum(fg_true, axis=spatial_axes) + tf.reduce_sum(fg_pred, axis=spatial_axes)
    per_channel = (2.0 * overlap + eps) / (total + eps)
    return tf.reduce_mean(per_channel)

# Soft IoU / Jaccard over the foreground channels (background excluded)
def kidney_tumor_iou(y_true, y_pred):
    """Soft IoU on channels ``1:`` (channel 0 is assumed to be background).

    IoU = (|A∩B| + eps) / (|A| + |B| - |A∩B| + eps), computed per sample and per foreground
    channel over axes 1-3, then averaged.
    """
    eps = 1e-6
    spatial_axes = [1, 2, 3]
    fg_true, fg_pred = y_true[..., 1:], y_pred[..., 1:]

    overlap = tf.reduce_sum(fg_true * fg_pred, axis=spatial_axes)
    joint = tf.reduce_sum(fg_true, axis=spatial_axes) + tf.reduce_sum(fg_pred, axis=spatial_axes) - overlap
    return tf.reduce_mean((overlap + eps) / (joint + eps))

# Per-class metrics
def kidney_dice(y_true, y_pred):
    """Soft Dice of channel 1 (kidney; assumes one-hot channel order background, kidney, tumor).

    Selecting one channel drops the channel axis, so the sums over axes 1-3 reduce each
    sample to a scalar. The per-sample Dice values are then averaged.
    """
    kidney = 1
    truth = y_true[..., kidney]
    pred = y_pred[..., kidney]
    axes = [1, 2, 3]

    numerator = 2.0 * tf.reduce_sum(truth * pred, axis=axes) + 1e-6
    denominator = tf.reduce_sum(truth, axis=axes) + tf.reduce_sum(pred, axis=axes) + 1e-6
    return tf.reduce_mean(numerator / denominator)

def tumor_dice(y_true, y_pred):
    """Soft Dice of channel 2 (tumor; assumes one-hot channel order background, kidney, tumor).

    Computed per sample over axes 1-3 of the selected channel, then averaged over the batch.
    """
    tumor = 2
    truth = y_true[..., tumor]
    pred = y_pred[..., tumor]
    axes = [1, 2, 3]

    numerator = 2.0 * tf.reduce_sum(truth * pred, axis=axes) + 1e-6
    denominator = tf.reduce_sum(truth, axis=axes) + tf.reduce_sum(pred, axis=axes) + 1e-6
    return tf.reduce_mean(numerator / denominator)

In [None]:
# Volumetric similarity on the kidney and tumor channels
def kidney_tumor_vol_sim(y_true, y_pred):
    """Volumetric similarity VS = 1 - |V_true - V_pred| / (V_true + V_pred + eps).

    Uses foreground channels ``1:`` only, with predictions binarised at 0.5. Volumes are
    voxel counts over axes 1-3. VS is computed per sample and per channel, then averaged.
    Two empty volumes give VS = 1.
    """
    spatial_axes = [1, 2, 3]
    pred_fg = K.cast(K.greater(y_pred[..., 1:], 0.5), 'float32')

    true_volume = K.sum(y_true[..., 1:], axis=spatial_axes)
    pred_volume = K.sum(pred_fg, axis=spatial_axes)

    volume_gap = K.abs(true_volume - pred_volume)
    return K.mean(1.0 - volume_gap / (true_volume + pred_volume + K.epsilon()))

In [None]:
def kidney_tumor_hausdorff(y_true, y_pred, max_dist=100.0):
    """
    Overlap-based surrogate for a Hausdorff distance on the kidney and tumor channels.

    NOTE: despite its name this is NOT a boundary (Hausdorff) distance. For every sample and
    every foreground channel (``1:``; channel 0 is assumed to be background) it computes

        max_dist * (2 - precision - recall)

    on masks binarised at 0.5. This gives 0 for perfect overlap and 2 * max_dist for no overlap.
    Special cases per (sample, channel):
      - class absent in both y_true and y_pred -> 0, since predicting nothing there is correct
        (the previous version returned max_dist, the worst score, for this case),
      - class present in exactly one of them   -> max_dist.
    The result is the mean over samples and channels.

    Args:
        y_true, y_pred: Channel-last tensors (batch, spatial..., channels). Every axis between
            the batch axis and the channel axis is reduced.
        max_dist (float): Scale of the surrogate distance.

    Returns:
        Scalar float32 tensor.
    """
    eps = K.epsilon()
    true_fg = tf.cast(tf.greater(tf.cast(y_true[..., 1:], tf.float32), 0.5), tf.float32)
    pred_fg = tf.cast(tf.greater(tf.cast(y_pred[..., 1:], tf.float32), 0.5), tf.float32)

    # Reduce every axis between batch and channel. The rank comes from y_pred, whose static
    # shape is known from the model (y_true's may not be).
    spatial_axes = list(range(1, len(pred_fg.shape) - 1))

    true_count = tf.reduce_sum(true_fg, axis=spatial_axes)   # (batch, channels)
    pred_count = tf.reduce_sum(pred_fg, axis=spatial_axes)
    overlap = tf.reduce_sum(true_fg * pred_fg, axis=spatial_axes)

    precision = overlap / (pred_count + eps)
    recall = overlap / (true_count + eps)
    surrogate = max_dist * (2.0 - precision - recall)

    true_empty = tf.equal(true_count, 0.0)
    pred_empty = tf.equal(pred_count, 0.0)
    dist = tf.where(
        tf.logical_and(true_empty, pred_empty),
        tf.zeros_like(surrogate),
        tf.where(tf.logical_or(true_empty, pred_empty), tf.ones_like(surrogate) * max_dist, surrogate),
    )
    return tf.reduce_mean(dist)

## Training Function

In [None]:
def _layer_output_shapes(layer):
    """Return the output shape(s) of a Keras layer as a list, or None if unknown.

    Single-output layers yield a one-element list; multi-output layers yield one
    shape per output. Each shape is an iterable of ints and/or ``None``.
    """
    # Try different attribute names for output shape (differs across Keras versions).
    if hasattr(layer, 'output_shape'):
        shapes = layer.output_shape
    elif hasattr(layer, 'output'):
        output = layer.output
        if isinstance(output, (list, tuple)):
            shapes = [t.shape for t in output]
        else:
            shapes = output.shape
    elif hasattr(layer, '_output_shape'):
        shapes = layer._output_shape
    else:
        return None

    # A list whose entries are themselves iterable is a list of per-output shapes;
    # anything else is a single shape.
    if isinstance(shapes, list) and shapes and hasattr(shapes[0], '__iter__'):
        return list(shapes)
    return [shapes]


def _count_activation_elements(model):
    """Sum the per-sample activation element counts of every layer in ``model``.

    Input layers are skipped and nested models are recursed into. ``None``
    (batch / dynamic) dimensions are ignored when multiplying out a shape.
    """
    total = 0
    for layer in model.layers:
        layer_type = layer.__class__.__name__

        # Input layers hold no activations of their own.
        if layer_type == 'InputLayer':
            continue

        # Nested models: count only their activations here. Their weights are
        # already part of the parent's weight lists, so they must not be re-counted.
        if layer_type in ('Model', 'Functional'):
            total += _count_activation_elements(layer)
            continue

        try:
            shapes = _layer_output_shapes(layer)
            if shapes is None:
                print(f"Warning: Could not determine output shape for layer {layer.name}")
                continue

            # Sum over ALL outputs (not just the first) of multi-output layers.
            layer_elements = 0
            for shape in shapes:
                elements = 1
                for dim in shape:
                    if dim is not None:
                        elements *= int(dim)
                layer_elements += elements
            total += layer_elements
        except Exception as e:
            print(f"Warning: Error processing layer {layer.name}: {e}")
    return total


def get_model_memory_usage(batch_size, model):
    """Estimate the GPU memory (in GB) needed to hold ``model`` at ``batch_size``.

    The estimate counts one number per model parameter plus one number per
    activation element of every layer output, with activations multiplied by
    ``batch_size``. Gradients, optimizer state and framework workspace are not
    included, so treat the result as a rough lower bound.

    Args:
        batch_size (int): Samples per training step.
        model: Keras model exposing ``layers``, ``trainable_weights`` and
            ``non_trainable_weights``.

    Returns:
        float | str: Estimate in GB rounded to 3 decimals, or the string
        ``"Memory estimation failed"`` if the estimate could not be computed.
    """
    try:
        import numpy as np

        # Bytes per number follow the Keras float type; default to float32 when
        # the keras backend cannot be queried.
        try:
            from keras import backend as K
            floatx = K.floatx()
        except Exception:
            floatx = 'float32'
        number_size = {'float16': 2.0, 'float64': 8.0}.get(floatx, 4.0)

        def count_params(weights):
            return sum(int(np.prod(w.shape)) for w in weights)

        # The model's weight lists already include the weights of nested
        # sub-models, so every parameter is counted exactly once.
        param_count = count_params(model.trainable_weights) + count_params(model.non_trainable_weights)
        activation_count = _count_activation_elements(model)

        total_memory = number_size * (batch_size * activation_count + param_count)
        return np.round(total_memory / (1024.0 ** 3), 3)
    except Exception as e:
        print(f"Error calculating model memory usage: {e}")
        return "Memory estimation failed"

In [None]:
def _plot_training_history(history, output_dir, model_name):
    """Plot loss and evaluation metrics from a Keras ``History`` in a 2x3 grid.

    The figure is saved to ``<output_dir>/<model_name>_history.png`` and shown.
    """
    # (panel title, y-axis label, [(history key, legend label), ...])
    panels = [
        ('Loss', 'Loss',
         [('loss', 'Training Loss'), ('val_loss', 'Validation Loss')]),
        ('Kidney-Tumor Dice Coefficient', 'Dice',
         [('kidney_tumor_dice', 'Training Dice'), ('val_kidney_tumor_dice', 'Validation Dice')]),
        ('Kidney-Tumor IoU', 'IoU',
         [('kidney_tumor_iou', 'Training IoU'), ('val_kidney_tumor_iou', 'Validation IoU')]),
        ('Kidney vs Tumor Dice', 'Dice',
         [('kidney_dice', 'Kidney Dice'), ('val_kidney_dice', 'Val Kidney Dice'),
          ('tumor_dice', 'Tumor Dice'), ('val_tumor_dice', 'Val Tumor Dice')]),
        ('Hausdorff Distance', 'Distance',
         [('kidney_tumor_hausdorff', 'Training Hausdorff'),
          ('val_kidney_tumor_hausdorff', 'Val Hausdorff')]),
        ('Volumetric Similarity', 'Similarity',
         [('kidney_tumor_vol_sim', 'Training Vol. Similarity'),
          ('val_kidney_tumor_vol_sim', 'Val Vol. Similarity')]),
    ]

    plt.figure(figsize=(15, 10))
    for index, (title, ylabel, series) in enumerate(panels, start=1):
        plt.subplot(2, 3, index)
        for key, label in series:
            plt.plot(history.history[key], label=label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'{model_name}_history.png'))
    plt.show()


def train_kits23_model(dataset_path, output_dir="models", model_name="kits23_unet",
                       batch_size=2, epochs=50, learning_rate=0.0001, patience=10,
                       backbone="resnet18", target_shape=(96, 96, 96)):
    """
    Train a 3D U-Net model on KiTS23 dataset with additional evaluation metrics

    Args:
        dataset_path (str): Path to the KiTS23 dataset
        output_dir (str): Directory to save models and logs
        model_name (str): Base name for the model files
        batch_size (int): Batch size for training
        epochs (int): Number of training epochs
        learning_rate (float): Initial learning rate
        patience (int): Patience for early stopping (on val_kidney_tumor_iou)
        backbone (str): Backbone architecture for the U-Net
        target_shape (tuple): Target shape for the input volumes

    Returns:
        tuple: ``(model, history)`` - the trained Keras model and the History
        object returned by ``model.fit``.

    Writes into ``output_dir``: a per-epoch checkpoint (``*_temp.keras``), the
    best-IoU checkpoint (``*_best.keras``), a CSV log, a history plot and the
    final model (``*_final.keras``).
    """
    #from tensorflow.keras.mixed_precision import set_global_policy
    #set_global_policy('mixed_float16')

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Initialize dataset
    dataset = KiTS23Dataset(
        dataset_path=dataset_path,
        target_shape=target_shape,
        use_cache=True,
        cache_limit=50
    )

    # Get number of classes
    num_classes = len(dataset.all_labels)
    print(f"Training model with {num_classes} classes")

    # Get preprocessing function for the backbone
    preprocess_input = sm.get_preprocessing(backbone)

    # Create dataset generators (test_split=0, so the test generator is not used here)
    train_gen, val_gen, test_gen, steps_per_epoch, validation_steps, _ = dataset.create_tf_datasets(
        batch_size=batch_size,
        validation_split=0.2,
        test_split=0,
        preprocess_fn=preprocess_input
    )

    # Setup model: 3 input channels, matching the ImageNet-pretrained encoder
    shape_size = (*target_shape, 3)
    encoder_weights = 'imagenet'

    model = sm.Unet(
        backbone,
        input_shape=shape_size,
        encoder_weights=encoder_weights,
        classes=num_classes,
    )

    # model.summary() prints the table itself and returns None, so it is not
    # wrapped in print() (that would print a stray "None").
    model.summary()
    print(f"Estimated GPU memory usage: {get_model_memory_usage(batch_size, model)} GB")

    # Setup optimizer: cosine decay with warm restarts, first cycle ~5 epochs long
    lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
        initial_learning_rate=learning_rate,
        first_decay_steps=steps_per_epoch * 5,
        t_mul=1.5,
        m_mul=0.95,
        alpha=0.001
    )

    optim = AdamW(
        learning_rate=lr_schedule,
        weight_decay=1e-5,
        clipnorm=1.0
    )

    def kidney_tumor_only_loss(y_true, y_pred, gamma=0.75):
        """
        Class-weighted focal Tversky loss over the kidney and tumor channels only.

        The Tversky terms are computed per sample and per class from the
        UNWEIGHTED masks, so for masks/probabilities in [0, 1] each Tversky
        index stays in [0, 1]. The class weights are applied afterwards to the
        per-class focal losses as a weighted average. (Weighting y_true itself
        makes ``1 - y_true`` negative, which can push the index above 1 and
        turn ``pow(1 - index, gamma)`` into NaN.)
        """
        # Work in float32 regardless of label dtype or mixed-precision outputs.
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        # Remove background channel (assumed to be channel 0)
        y_true_pos = y_true[..., 1:]
        y_pred_pos = y_pred[..., 1:]

        # Class weights: kidney, tumor.
        # NOTE(review): assumes exactly two foreground channels after dropping background.
        class_weights = tf.constant([1.0, 5.0], dtype=tf.float32)

        smooth = 1e-6
        alpha = 0.3  # weight on false negatives
        beta = 0.7   # weight on false positives

        # Sum over the three spatial axes -> shape (batch, n_foreground_classes)
        spatial_axes = [1, 2, 3]
        true_pos = tf.reduce_sum(y_true_pos * y_pred_pos, axis=spatial_axes)
        false_neg = tf.reduce_sum(y_true_pos * (1 - y_pred_pos), axis=spatial_axes)
        false_pos = tf.reduce_sum((1 - y_true_pos) * y_pred_pos, axis=spatial_axes)

        tversky = (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)

        # Clamp away from 0: pow with an exponent < 1 has an infinite gradient at 0.
        tversky_loss = tf.maximum(1.0 - tversky, 1e-7)
        focal_tversky = tf.pow(tversky_loss, gamma)

        # Class-weighted average over classes, then mean over the batch.
        weighted = tf.reduce_sum(focal_tversky * class_weights, axis=-1) / tf.reduce_sum(class_weights)
        return tf.reduce_mean(weighted)

    # Compile model with all metrics
    model.compile(
        optimizer=optim,
        loss=kidney_tumor_only_loss,
        metrics=[
            kidney_tumor_dice,
            kidney_tumor_iou,
            kidney_dice,
            tumor_dice,
            kidney_tumor_vol_sim,
            kidney_tumor_hausdorff
        ]
    )

    # Setup callbacks
    model_path = os.path.join(output_dir, f"{model_name}")
    cache_model_path = f'{model_path}_temp.keras'
    best_model_path = os.path.join(output_dir, f"{model_name}_best.keras")

    log_file = os.path.join(output_dir, f'history_{model_name}.csv')

    callbacks = [
        ModelCheckpoint(cache_model_path, monitor='val_loss', verbose=1),
        ModelCheckpoint(best_model_path, monitor='val_kidney_tumor_iou', verbose=1, save_best_only=True, mode='max'),
        #ReduceLROnPlateau(monitor='val_kidney_tumor_iou', factor=0.9, patience=5, min_lr=1e-7, min_delta=1e-6, verbose=1, mode='max'),
        CSVLogger(log_file, append=True),
        EarlyStopping(monitor='val_kidney_tumor_iou', patience=patience, verbose=1, mode='max', restore_best_weights=True),
    ]

    # Fit model
    history = model.fit(
        train_gen,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        validation_data=val_gen,
        validation_steps=validation_steps,
        verbose=1,
        callbacks=callbacks
    )

    # Evaluate model
    max_iou = max(history.history['val_kidney_tumor_iou'])
    print(f'Training finished. Max Kidney-Tumor IoU: {max_iou:.4f}')

    # Plot training history
    _plot_training_history(history, output_dir, model_name)

    # Save the model in its current state. NOTE: depending on the Keras version,
    # EarlyStopping(restore_best_weights=True) may only restore the best weights
    # when it actually stops early; the best-IoU weights are always available in
    # best_model_path.
    final_model_path = os.path.join(output_dir, f"{model_name}_final.keras")
    model.save(final_model_path)

    return model, history

## Train

In [None]:
# Run configuration: paths, optimisation schedule and model settings used by
# the training, visualization and export cells.
dataset_path = '/content/kits23/dataset/'
output_dir = '/content/models'
model_name = 'kits23_unet'

# Optimisation schedule
batch_size = 8
epochs = 128
learning_rate = 1e-4
patience = 40  # epochs without val_kidney_tumor_iou improvement before early stopping

# Architecture / input geometry
backbone = 'resnet18'  # other backbones, e.g. 'resnet34' or 'vgg16', can be tried
target_shape = (128,) * 3

In [None]:
# Launch training with the run configuration; yields the trained model and the
# Keras History returned by model.fit.
model, history = train_kits23_model(
    dataset_path=dataset_path, output_dir=output_dir, model_name=model_name,
    batch_size=batch_size, epochs=epochs, learning_rate=learning_rate,
    patience=patience, backbone=backbone, target_shape=target_shape,
)

# Dataset Sample Visualization

In [None]:
def _mask_overlay_rgb(mask_2d):
    """Build an RGB overlay image from a 2-D, channels-last slice of a mask.

    Channel 1 (kidney) is drawn in red and channel 2 (tumor) in green; channels
    that are missing leave the corresponding color black.
    """
    rgb = np.zeros((*mask_2d.shape[:2], 3))
    if mask_2d.shape[-1] > 1:  # Kidney (class 1) - red
        rgb[:, :, 0] = mask_2d[:, :, 1]
    if mask_2d.shape[-1] > 2:  # Tumor (class 2) - green
        rgb[:, :, 1] = mask_2d[:, :, 2]
    return rgb


def visualize_sample(dataset, case_id=None):
    """Visualize a sample case from the dataset with more detailed views.

    Saves and shows two figures in the working directory:
    ``visualization_<case_id>.png`` (three slices along axis 0 at 1/4, 1/2 and
    3/4 depth, each with the background channel and a kidney/tumor overlay) and
    ``visualization_orthogonal_<case_id>.png`` (mid-planes along each axis).

    Args:
        dataset: Object exposing ``cases``, ``load_case``, ``_get_imaging_path``,
            ``_get_segmentation_path`` and ``_load_nifti`` (e.g. KiTS23Dataset).
        case_id: Case to show; a random entry of ``dataset.cases`` if None.

    NOTE(review): assumes ``image`` and ``mask`` from ``load_case`` are
    channels-last volumes with matching spatial dims; the coronal/sagittal
    names below assume a particular axis orientation - confirm against the loader.
    """
    if case_id is None:
        case_id = random.choice(dataset.cases)

    print(f"Visualizing case: {case_id}")

    # Load the processed (resampled) case
    image, mask = dataset.load_case(case_id)

    # Original file paths, for inspection
    imaging_path = dataset._get_imaging_path(case_id)
    segmentation_path = dataset._get_segmentation_path(case_id)
    print(f"Original files: \nImage: {imaging_path}\nSegmentation: {segmentation_path}")

    # Report the raw data for sanity checking; failures here are non-fatal.
    try:
        orig_img = dataset._load_nifti(imaging_path)
        orig_seg = dataset._load_nifti(segmentation_path)
        print(f"Original image shape: {orig_img.shape}")
        print(f"Original mask shape: {orig_seg.shape}")
        print(f"Original mask unique values: {np.unique(orig_seg)}")
    except Exception as e:
        print(f"Error loading original files: {e}")

    print(f"Processed image shape: {image.shape}")
    print(f"Processed mask shape: {mask.shape}")

    # --- Figure 1: slices along axis 0 at 1/4, 1/2 and 3/4 depth ---
    depth = image.shape[0]
    slice_indices = [depth // 4, depth // 2, 3 * depth // 4]

    fig, axes = plt.subplots(3, 3, figsize=(15, 12))
    for row, slice_idx in enumerate(slice_indices):
        axes[row, 0].imshow(image[slice_idx, :, :, 0], cmap='gray')
        axes[row, 0].set_title(f"CT Slice {slice_idx}")
        axes[row, 0].axis('off')

        # Background channel (class 0)
        if mask.shape[-1] > 0:
            axes[row, 1].imshow(mask[slice_idx, :, :, 0], cmap='Blues')
            axes[row, 1].set_title("Background (class 0)")
            axes[row, 1].axis('off')

        axes[row, 2].imshow(image[slice_idx, :, :, 0], cmap='gray')
        axes[row, 2].imshow(_mask_overlay_rgb(mask[slice_idx]), alpha=0.5)
        axes[row, 2].set_title("Overlay (red=kidney, green=tumor)")
        axes[row, 2].axis('off')

    plt.tight_layout()
    plt.savefig(f'visualization_{case_id}.png')
    plt.show()
    plt.close(fig)

    # --- Figure 2: mid-planes along each spatial axis ---
    # Each mid index is derived from the size of the axis it indexes, so
    # non-cubic volumes are sliced correctly.
    mid_z = image.shape[0] // 2  # axis 0
    mid_y = image.shape[1] // 2  # axis 1
    mid_x = image.shape[2] // 2  # axis 2

    # (view name, CT plane, mask plane with all channels)
    views = [
        ("Coronal", image[:, mid_y, :, 0], mask[:, mid_y, :, :]),
        ("Sagittal", image[:, :, mid_x, 0], mask[:, :, mid_x, :]),
        ("Axial", image[mid_z, :, :, 0], mask[mid_z, :, :, :]),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    for col, (view_name, ct_plane, mask_plane) in enumerate(views):
        axes[0, col].imshow(ct_plane, cmap='gray')
        axes[0, col].set_title(f"{view_name} CT View")
        axes[0, col].axis('off')

        axes[1, col].imshow(ct_plane, cmap='gray')
        axes[1, col].imshow(_mask_overlay_rgb(mask_plane), alpha=0.5)
        axes[1, col].set_title(f"{view_name} Mask Overlay")
        axes[1, col].axis('off')

    plt.tight_layout()
    plt.savefig(f'visualization_orthogonal_{case_id}.png')
    plt.show()
    plt.close(fig)

In [None]:
# Dataset handle built with the constructor's default caching settings, used
# only to inspect one randomly chosen case.
dataset = KiTS23Dataset(target_shape=target_shape, dataset_path=dataset_path)
visualize_sample(dataset, case_id=None)

In [None]:
! zip -r models.zip /content/models/

# Save Model to Drive

In [None]:
from google.colab import drive
import os
import shutil

def save_to_google_drive(model, model_dir, model_name, training_config, drive_folder="KiTS23_Models"):
    """
    Save the trained model and training config into ``model_dir``, then copy
    them together with the training history (CSV log and plot) to Google Drive.

    Args:
        model (tf.keras.Model): The trained Keras model to save.
        model_dir (str): The directory where the model and logs are stored.
        model_name (str): The base name of the model files (without extension).
        training_config (dict): Dictionary containing training parameters.
        drive_folder (str): The name of the folder in Google Drive to save files.

    Returns:
        str: The full path to the Google Drive folder.
    """

    # 1️⃣ Mount Google Drive
    drive.mount('/content/drive')

    # 2️⃣ Define (and create if needed) the Google Drive destination folder
    drive_path = f"/content/drive/My Drive/{drive_folder}"
    os.makedirs(drive_path, exist_ok=True)

    # The model and config are written locally first, so model_dir must exist.
    os.makedirs(model_dir, exist_ok=True)

    # 3️⃣ Define local file paths
    model_path = os.path.join(model_dir, f"{model_name}.keras")  # Trained model file
    history_csv = os.path.join(model_dir, f"history_{model_name}.csv")  # CSV log from training
    history_png = os.path.join(model_dir, f"{model_name}_history.png")  # Training plot
    config_txt = os.path.join(model_dir, f"{model_name}_config.txt")  # Training config

    # 4️⃣ Save the current model state
    print("🔄 Saving model...")
    model.save(model_path)
    print(f"✅ Model saved at: {model_path}")

    # 5️⃣ Save training configuration as "key: value" lines
    with open(config_txt, "w", encoding="utf-8") as f:
        for key, value in training_config.items():
            f.write(f"{key}: {value}\n")
    print(f"✅ Training config saved: {config_txt}")

    # 6️⃣ Copy files to Google Drive (local originals are kept). The CSV log and
    # plot only exist if train_kits23_model wrote them with the same
    # model_dir/model_name, so report anything that was skipped.
    missing = []
    for file_path in [model_path, history_csv, history_png, config_txt]:
        if os.path.exists(file_path):
            shutil.copy(file_path, drive_path)
            print(f"✅ Saved {os.path.basename(file_path)} to Google Drive: {drive_path}")
        else:
            missing.append(os.path.basename(file_path))

    if missing:
        print(f"⚠️ Not found, skipped: {', '.join(missing)}")
    else:
        print("🚀 All files successfully saved to Google Drive!")

    return drive_path  # Return the folder path


In [None]:
# Define training configuration
training_config = {
    "dataset_path": dataset_path,
    "output_dir": output_dir,
    "model_name": model_name,
    "batch_size": batch_size,
    "epochs": epochs,
    "learning_rate": learning_rate,
    "patience": patience,
    "backbone": backbone,
    "target_shape": target_shape,
}

# Save model & logs to Google Drive
#save_to_google_drive(model, output_dir, model_name, training_config, drive_folder="3dunet-31-march")
