In [None]:
# Import libraries
import os
import sys
import subprocess
import tarfile
import shutil
import requests
from pathlib import Path
import tempfile
import json
import numpy as np
from tqdm import tqdm

# This cell only imports the libraries above (nothing is installed here);
# nibabel is installed and imported in a later cell via %pip.
print("✅ Dependencies installed and imported")


In [None]:
# Configure the three data roots nnU-Net v2 reads from environment variables.
# Each folder is created under <project_root>/datasets if it does not exist.
notebook_dir = Path.cwd()
project_root = notebook_dir.parent  # assumes the notebook sits one level below the project root (e.g. notebooks/)
datasets_path = project_root / "datasets"

# Same keys, same order: env-var name -> absolute folder path (as str).
env_vars = {
    name: str(datasets_path / name)
    for name in ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")
}

for env_name, folder in env_vars.items():
    os.environ[env_name] = folder
    Path(folder).mkdir(parents=True, exist_ok=True)
    print(f"✅ Set {env_name} = {folder}")

print("\n🔍 Environment verification:")
for env_name in env_vars:
    print(f"   {env_name}: {os.environ[env_name]}")


In [None]:
def download_file_with_progress(url, destination, timeout=60, chunk_size=8192):
    """Stream ``url`` to ``destination`` while showing a tqdm progress bar.

    Parameters
    ----------
    url : str
        URL to fetch.
    destination : str or Path
        File to write; overwritten if it already exists.
    timeout : float, default 60
        Seconds passed to ``requests.get(timeout=...)``. Without a timeout a
        stalled server would block the cell indefinitely.
    chunk_size : int, default 8192
        Bytes requested per ``iter_content`` chunk.

    Returns
    -------
    bool
        True when the whole body was written. On any error the message is
        printed, a partially written ``destination`` is removed, and False is
        returned.
    """
    destination = Path(destination)
    try:
        # Using the response as a context manager releases the streamed
        # connection even if writing fails part-way.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # content-length may be absent (chunked transfer); None tells tqdm
            # the total is unknown instead of showing a bogus 0-byte total.
            total_size = int(response.headers.get('content-length', 0)) or None

            try:
                with open(destination, 'wb') as file, tqdm(
                    desc=destination.name,
                    total=total_size,
                    unit='B',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as pbar:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        if chunk:  # skip keep-alive chunks
                            file.write(chunk)
                            pbar.update(len(chunk))
            except BaseException:
                # Don't leave a truncated file that could later be mistaken
                # for a complete download (also on KeyboardInterrupt).
                destination.unlink(missing_ok=True)
                raise
        return True
    except Exception as e:
        print(f"❌ Download failed: {e}")
        return False

# Try to download the MSD Task02_Heart archive into a fresh temporary folder.
# A download only counts as successful if the file really is a tar archive:
# a server can answer HTTP 200 with an HTML page, which would otherwise crash
# the extraction step later.
download_dir = Path(tempfile.mkdtemp(prefix="nnunet_download_"))
print(f"📁 Using temporary directory: {download_dir}")

# Candidate URLs, tried in order (may need updating over time).
msd_urls = [
    "http://medicaldecathlon.com/files/Task02_Heart.tar",
    # Mirror used by MONAI's DecathlonDataset — verify it is still reachable.
    "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task02_Heart.tar",
]

tar_file = download_dir / "Task02_Heart.tar"
downloaded = False

print("📥 Attempting to download MSD Task02_Heart...")
for url in msd_urls:
    print(f"🌐 Trying: {url}")
    if not download_file_with_progress(url, tar_file):
        continue
    if tarfile.is_tarfile(tar_file):
        downloaded = True
        print("✅ Download successful!")
        break
    print(f"❌ {url} did not return a valid tar archive")
    tar_file.unlink(missing_ok=True)

if not downloaded:
    print("⚠️ MSD download failed. Will create synthetic dataset instead.")
else:
    print(f"📦 Downloaded to: {tar_file}")


In [None]:
# Install nibabel for creating NIfTI files
%pip install nibabel
import nibabel as nib


In [None]:
def create_synthetic_dataset(base_dir, num_cases=10, seed=None):
    """Create a small synthetic nnU-Net v2 dataset for smoke-testing the pipeline.

    Writes ``Dataset999_Synthetic/`` under ``base_dir`` containing:

    - ``imagesTr/case_XXX_0000.nii.gz``: 64x64x32 int16 volumes of uniform
      noise in [0, 1000) with +200 added inside the box [20:40, 20:40, 10:20];
    - ``labelsTr/case_XXX.nii.gz``: uint8 masks with label 1 ("organ") in the
      smaller box [25:35, 25:35, 12:18] and 0 elsewhere;
    - ``dataset.json`` declaring one channel, labels background=0 / organ=1,
      ``numTraining`` and the ``.nii.gz`` file ending.

    All volumes are saved with an identity affine.

    Parameters
    ----------
    base_dir : str or Path
        Parent folder for the dataset; created if missing.
    num_cases : int, default 10
        Number of image/label pairs to write.
    seed : int or None, default None
        Seed for the random generator. Pass an int to get identical data on
        every run; ``None`` draws fresh entropy on each call.

    Returns
    -------
    Path
        The ``Dataset999_Synthetic`` folder.
    """
    base_dir = Path(base_dir)  # accept plain strings as well as Path objects
    rng = np.random.default_rng(seed)

    print(f"🔬 Creating synthetic dataset with {num_cases} cases...")

    dataset_dir = base_dir / "Dataset999_Synthetic"
    images_tr = dataset_dir / "imagesTr"
    labels_tr = dataset_dir / "labelsTr"

    images_tr.mkdir(parents=True, exist_ok=True)
    labels_tr.mkdir(parents=True, exist_ok=True)

    for i in tqdm(range(num_cases), desc="Creating cases"):
        case_id = f"case_{i:03d}"

        # Small 3D volume (64x64x32) so preprocessing/training stay quick.
        # Max value 999 + 200 fits comfortably in int16.
        image_data = rng.integers(0, 1000, size=(64, 64, 32), dtype=np.int16)
        image_data[20:40, 20:40, 10:20] += 200  # brighter simulated "organ"

        seg_data = np.zeros((64, 64, 32), dtype=np.uint8)
        seg_data[25:35, 25:35, 12:18] = 1  # "organ" label

        img_nifti = nib.Nifti1Image(image_data, affine=np.eye(4))
        seg_nifti = nib.Nifti1Image(seg_data, affine=np.eye(4))

        # nnU-Net v2 naming: images carry a 4-digit channel suffix, labels do not.
        nib.save(img_nifti, images_tr / f"{case_id}_0000.nii.gz")
        nib.save(seg_nifti, labels_tr / f"{case_id}.nii.gz")

    dataset_json = {
        "channel_names": {"0": "synthetic"},
        "labels": {"background": 0, "organ": 1},
        "numTraining": num_cases,
        "file_ending": ".nii.gz"
    }

    with open(dataset_dir / "dataset.json", 'w') as f:
        json.dump(dataset_json, f, indent=2)

    print(f"✅ Created synthetic dataset at: {dataset_dir}")
    return dataset_dir

# Pick the dataset to hand to nnU-Net: synthetic data when the download failed,
# otherwise the extracted MSD archive (falling back to synthetic data if the
# archive cannot be extracted or contains no Task02 folder).
if not downloaded:
    dataset_path = create_synthetic_dataset(download_dir)
    dataset_id = 999
    print(f"📊 Synthetic dataset ready at: {dataset_path}")
else:
    print("📂 Extracting MSD dataset...")
    extracted_folders = []
    try:
        with tarfile.open(tar_file, 'r') as tar:
            # The archive comes from the internet. Where the running Python
            # supports extraction filters, the "data" filter refuses members
            # that would be written outside download_dir.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(download_dir, filter="data")
            else:
                tar.extractall(download_dir)
        # sorted() keeps the choice deterministic if several folders match.
        extracted_folders = sorted(
            d for d in download_dir.iterdir() if d.is_dir() and 'Task02' in d.name
        )
    except (tarfile.TarError, OSError) as e:
        print(f"❌ Could not extract {tar_file}: {e}")

    if extracted_folders:
        dataset_path = extracted_folders[0]
        dataset_id = 2
        print(f"✅ Extracted MSD dataset at: {dataset_path}")
    else:
        print("⚠️ Extraction failed, creating synthetic dataset instead")
        dataset_path = create_synthetic_dataset(download_dir)
        dataset_id = 999


In [None]:
def convert_to_nnunet_format(dataset_path, target_id):
    """Place a dataset into ``$nnUNet_raw`` using nnU-Net v2 conventions.

    MSD-style folders (name starting with ``Task``, e.g. ``Task02_Heart``) are
    converted with the ``nnUNetv2_convert_MSD_dataset`` CLI using
    ``-overwrite_id target_id``. Any other folder is assumed to already be in
    nnU-Net v2 layout and is copied to
    ``$nnUNet_raw/Dataset{target_id:03d}_Synthetic``, replacing a previous copy.

    Parameters
    ----------
    dataset_path : str or Path
        Folder holding the source dataset.
    target_id : int
        nnU-Net dataset ID to use.

    Returns
    -------
    int or None
        ``target_id`` on success; ``None`` if the MSD converter failed or is
        not installed (the reason is printed).

    Raises
    ------
    KeyError
        If the ``nnUNet_raw`` environment variable is not set.
    """
    dataset_path = Path(dataset_path)
    print(f"🔄 Converting dataset to nnU-Net format (ID: {target_id})...")

    # Read up front so a missing configuration fails fast for both branches.
    raw_data_folder = Path(os.environ['nnUNet_raw'])

    if dataset_path.name.startswith('Task'):
        # Official MSD layout: let nnU-Net's converter handle it.
        cmd = [
            'nnUNetv2_convert_MSD_dataset',
            '-i', str(dataset_path),
            '-overwrite_id', str(target_id)
        ]
        try:
            subprocess.run(cmd, check=True, capture_output=True, text=True)
        except FileNotFoundError:
            print("❌ MSD conversion failed: 'nnUNetv2_convert_MSD_dataset' not found "
                  "- is nnunetv2 installed in this kernel's environment?")
            return None
        except subprocess.CalledProcessError as e:
            print(f"❌ MSD conversion failed: {e}")
            print(f"Error output: {e.stderr}")
            return None
        print("✅ MSD conversion successful")
        return target_id

    # Already nnU-Net v2 formatted (e.g. the synthetic dataset): copy as-is.
    target_path = raw_data_folder / f"Dataset{target_id:03d}_Synthetic"
    if target_path.exists():
        shutil.rmtree(target_path)  # copytree refuses an existing destination

    shutil.copytree(dataset_path, target_path)
    print(f"✅ Copied synthetic dataset to: {target_path}")
    return target_id

# Convert (or copy) the dataset into $nnUNet_raw. Every later step depends on
# it, so stop the notebook here if conversion failed.
converted_id = convert_to_nnunet_format(dataset_path, dataset_id)

# Compare with None explicitly: convert_to_nnunet_format signals failure with
# None, whereas a plain truthiness test would also reject an ID of 0.
if converted_id is None:
    print("❌ Dataset conversion failed")
    raise RuntimeError("Cannot proceed without successful conversion")

print(f"🎯 Dataset successfully converted with ID: {converted_id}")


In [None]:
def run_preprocessing(dataset_id, tail_lines=5):
    """Run ``nnUNetv2_plan_and_preprocess -d <id> --verify_dataset_integrity``.

    Parameters
    ----------
    dataset_id : int or str
        nnU-Net dataset ID passed to ``-d``.
    tail_lines : int, default 5
        How many trailing lines of the tool's stdout to echo (blank lines are
        skipped). Use 0 to suppress.

    Returns
    -------
    bool
        True if the command exited with status 0. False if it failed or the
        executable could not be found; the reason is printed.
    """
    print(f"⚙️ Running nnU-Net preprocessing for dataset {dataset_id}...")
    print("This may take a few minutes...")

    cmd = [
        'nnUNetv2_plan_and_preprocess',
        '-d', str(dataset_id),
        '--verify_dataset_integrity'
    ]

    def _tail(text):
        # Last `tail_lines` lines of captured output, without blank ones.
        if not text or tail_lines <= 0:
            return []
        return [line for line in text.strip().splitlines()[-tail_lines:] if line.strip()]

    try:
        # Blocks until planning + preprocessing finish; output is captured, not streamed.
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except FileNotFoundError:
        print("❌ Preprocessing failed: 'nnUNetv2_plan_and_preprocess' not found "
              "- is nnunetv2 installed in this kernel's environment?")
        return False
    except subprocess.CalledProcessError as e:
        print(f"❌ Preprocessing failed: {e}")
        print(f"Error output: {e.stderr}")
        # The tool may also report the cause on stdout, so show its tail too.
        stdout_tail = _tail(e.stdout)
        if stdout_tail:
            print("Last output lines:")
            for line in stdout_tail:
                print(f"   {line}")
        return False

    print("✅ Preprocessing completed successfully!")
    print("🔍 Dataset integrity verified")

    summary = _tail(result.stdout)
    if summary:
        print("\n📋 Preprocessing summary:")
        for line in summary:
            print(f"   {line}")

    return True

# Plan + preprocess the converted dataset. A failure is reported here but the
# notebook keeps going; the verification cell below checks the result.
preprocessing_success = run_preprocessing(converted_id)
print(
    "🎉 Preprocessing completed successfully!"
    if preprocessing_success
    else "⚠️ Preprocessing had issues, but we can continue"
)


In [None]:
def verify_setup(dataset_id):
    """Report whether nnU-Net is configured and dataset ``dataset_id`` is ready to train.

    Prints a ✅/❌ line for each check:

    - the ``nnUNet_raw``, ``nnUNet_preprocessed`` and ``nnUNet_results``
      environment variables are set;
    - ``$nnUNet_raw/Dataset{id:03d}_*`` exists and has ``imagesTr/``,
      ``labelsTr/`` and ``dataset.json`` (``*.nii.gz`` counts are printed);
    - ``$nnUNet_preprocessed/Dataset{id:03d}_*`` exists and contains a JSON
      file whose name includes "plans" (case-insensitive, e.g. ``nnUNetPlans.json``).

    Parameters
    ----------
    dataset_id : int
        Numeric nnU-Net dataset ID (formatted as three digits).

    Returns
    -------
    bool
        True only if every check passes. If an environment variable is missing
        the folder checks are skipped and False is returned.
    """
    print("🔍 Verifying setup...\n")

    print("📍 Environment Variables:")
    required_vars = ['nnUNet_raw', 'nnUNet_preprocessed', 'nnUNet_results']
    all_vars_ok = True

    for var in required_vars:
        if var in os.environ:
            print(f"   ✅ {var}: {os.environ[var]}")
        else:
            print(f"   ❌ {var}: Not set")
            all_vars_ok = False

    if not all_vars_ok:
        # Without the base paths there is nothing to look for, and indexing
        # os.environ below would raise KeyError.
        print("\n   ❌ Skipping dataset checks until all environment variables are set")
        return False

    print(f"\n📁 Dataset Files (ID: {dataset_id}):")

    pattern = f"Dataset{dataset_id:03d}_*"
    raw_path = Path(os.environ['nnUNet_raw'])
    preprocessed_path = Path(os.environ['nnUNet_preprocessed'])

    # sorted() so the reported folder is deterministic if several match.
    raw_datasets = sorted(raw_path.glob(pattern))
    preprocessed_datasets = sorted(preprocessed_path.glob(pattern))

    datasets_ok = True

    if raw_datasets:
        raw_dataset = raw_datasets[0]
        print(f"   ✅ Raw dataset: {raw_dataset}")

        images_tr = raw_dataset / "imagesTr"
        labels_tr = raw_dataset / "labelsTr"
        dataset_json = raw_dataset / "dataset.json"

        if images_tr.is_dir():
            num_images = len(list(images_tr.glob("*.nii.gz")))
            print(f"      📊 Training images: {num_images}")
        else:
            print(f"      ❌ imagesTr folder missing: {images_tr}")
            datasets_ok = False

        if labels_tr.is_dir():
            num_labels = len(list(labels_tr.glob("*.nii.gz")))
            print(f"      🏷️  Training labels: {num_labels}")
        else:
            print(f"      ❌ labelsTr folder missing: {labels_tr}")
            datasets_ok = False

        if dataset_json.is_file():
            print("      📋 Dataset.json: ✅")
        else:
            print(f"      ❌ dataset.json missing: {dataset_json}")
            datasets_ok = False
    else:
        print(f"   ❌ Raw dataset not found at: {raw_path / pattern}")
        datasets_ok = False

    if preprocessed_datasets:
        preprocessed_dataset = preprocessed_datasets[0]
        print(f"   ✅ Preprocessed dataset: {preprocessed_dataset}")

        # Case-insensitive match: glob("*plans*.json") is case-sensitive on
        # POSIX and would miss a file such as "nnUNetPlans.json".
        plans_files = sorted(
            p for p in preprocessed_dataset.glob("*.json") if "plans" in p.name.lower()
        )
        if plans_files:
            print(f"      📋 Plans file: {plans_files[0].name}")
        else:
            print("      ❌ No plans file found - preprocessing may not have finished")
            datasets_ok = False
    else:
        print(f"   ❌ Preprocessed dataset not found at: {preprocessed_path / pattern}")
        datasets_ok = False

    return all_vars_ok and datasets_ok

# Sanity-check the environment variables plus the raw and preprocessed folders.
setup_ok = verify_setup(converted_id)
print(
    "\n🎉 Setup verification successful! You're ready to train!"
    if setup_ok
    else "\n⚠️ Some issues found, but you may still be able to proceed"
)


In [None]:
def print_training_instructions(dataset_id):
    """Print copy-pasteable nnU-Net v2 training and inference commands.

    Commands are written as notebook shell escapes (``!cmd``) with
    ``dataset_id`` substituted. The preprocessed-folder hint formats the ID
    as three digits (e.g. ``Dataset002_*``).

    Parameters
    ----------
    dataset_id : int
        nnU-Net dataset ID.

    Returns
    -------
    bool
        Always True.
    """
    def emit(*lines):
        # Writes each argument on its own line, like one print() per line.
        print(*lines, sep="\n")

    banner = "=" * 60
    emit(
        "\n" + banner,
        "🎯 TRAINING INSTRUCTIONS",
        banner,
        f"Your dataset (ID: {dataset_id}) is ready for training!\n",
    )
    emit(
        "📋 Basic Training Commands:",
        "   # Quick 2D training (faster, good for testing)",
        f"   !nnUNetv2_train {dataset_id} 2d 0",
        "",
        "   # 3D training (slower, usually better results)",
        f"   !nnUNetv2_train {dataset_id} 3d_fullres 0",
    )
    emit(
        "\n🔍 Advanced Commands:",
        "   # Find best configuration after training multiple models",
        f"   !nnUNetv2_find_best_configuration {dataset_id}",
        "\n   # Run inference on new data",
        f"   !nnUNetv2_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -d {dataset_id} -c 2d -f 0",
    )
    emit(
        "\n📚 Helpful Commands:",
        "   !nnUNetv2_train -h                    # Training help",
        "   !nnUNetv2_predict -h                 # Prediction help",
    )
    # Separate call: the :03d format is the only line that can raise (for a
    # non-integer ID), and everything before it must already be printed.
    emit(f"   !ls $nnUNet_preprocessed/Dataset{dataset_id:03d}_*/     # Check preprocessed data")
    emit(
        "\n💡 Tips:",
        "   - Start with 2d training for quick testing",
        "   - 3d_fullres usually gives better results for 3D data",
        "   - Use fold 0 for quick testing",
        "   - For full cross-validation, train folds 0,1,2,3,4",
        "   - Training time depends on dataset size and hardware",
    )

    return True

# The trailing ';' stops Jupyter from echoing the function's return value
# (True) as this cell's output below the printed instructions.
print_training_instructions(converted_id);


In [None]:
# Remove the temporary download/extraction folder. convert_to_nnunet_format has
# already placed the dataset in $nnUNet_raw (by copying it, or, for MSD data,
# via nnUNetv2_convert_MSD_dataset), so nnU-Net no longer needs this folder.
# globals() (not locals()) states the intent at notebook top level; the guard
# lets this cell run even if the download cell was skipped.
if 'download_dir' in globals() and download_dir.exists():
    print(f"🧹 Cleaning up temporary directory: {download_dir}")
    shutil.rmtree(download_dir, ignore_errors=True)
    # ignore_errors=True hides failures, so check the result before claiming success.
    if download_dir.exists():
        print(f"⚠️ Could not fully remove {download_dir}; delete it manually")
    else:
        print("✅ Cleanup completed")

# Only claim success when the verification cell passed.
if globals().get('setup_ok', False):
    print("\n🎉 Dataset setup completed successfully!")
    print(f"📊 Dataset ID: {converted_id}")
    print("🚀 You can now start training your nnU-Net models!")
else:
    print("\n⚠️ Dataset setup finished with issues - review the verification output above")
    print(f"📊 Dataset ID: {converted_id}")


In [None]:
def extract_tar_file(tar_path, destination_path, verbose=True):
    """
    Extract a tar archive into a destination folder and return what it contained.

    Parameters
    -----------
    tar_path : str or Path
        Path to the tar file to extract (uncompressed or gzip/bz2/xz-compressed;
        compression is auto-detected by ``tarfile.open(..., 'r')``).
    destination_path : str or Path
        Directory where the archive contents will be extracted. Created if it
        does not exist; existing content is left in place.
    verbose : bool
        Whether to print progress information.

    Returns
    --------
    Path or None
        - the archive's top-level directory, if it has exactly one;
        - the alphabetically first top-level directory, if it has several;
        - ``destination_path`` itself, if the archive holds only top-level files;
        - ``None`` if ``tar_path`` is not a file or extraction failed.

        Only entries stored in the archive are considered, so folders that
        already existed in ``destination_path`` are never returned by mistake.

    Notes
    -----
    When the running Python supports extraction filters, the ``"data"`` filter
    is used so members that would be written outside ``destination_path`` are
    refused instead of extracted.

    Example
    --------
    # Extract a downloaded MSD dataset
    tar_file = "/path/to/Task02_Heart.tar"
    extract_dir = "/path/to/extract/location"
    extracted_folder = extract_tar_file(tar_file, extract_dir)

    if extracted_folder:
        print(f"Successfully extracted to: {extracted_folder}")
    """
    tar_path = Path(tar_path)
    destination_path = Path(destination_path)

    # is_file() (rather than exists()) also rejects directories.
    if not tar_path.is_file():
        print(f"❌ Error: Tar file not found: {tar_path}")
        return None

    destination_path.mkdir(parents=True, exist_ok=True)

    try:
        # Sniff the content instead of trusting the extension (e.g. ".tgz").
        # This is only a warning; tarfile.open below fails for non-archives.
        if not tarfile.is_tarfile(tar_path):
            print(f"⚠️  Warning: File doesn't appear to be a tar archive: {tar_path}")

        if verbose:
            print(f"📂 Extracting {tar_path.name} to {destination_path}")
            print(f"   Source: {tar_path}")
            print(f"   Destination: {destination_path}")

        with tarfile.open(tar_path, 'r') as tar:
            members = tar.getmembers()

            if verbose:
                print(f"   📊 Total files to extract: {len(members)}")

            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=destination_path, filter="data")
            else:
                tar.extractall(path=destination_path)

        if verbose:
            print("✅ Extraction completed successfully!")

        # Top-level names recorded in the archive itself. Leading "/" and "./"
        # are normalised away; ".." does not name new content and is skipped.
        top_level_names = set()
        for member in members:
            parts = Path(member.name.lstrip("/")).parts
            if parts and parts[0] != "..":
                top_level_names.add(parts[0])

        directories = [
            destination_path / name
            for name in sorted(top_level_names)
            if (destination_path / name).is_dir()
        ]

        if len(directories) == 1:
            extracted_folder = directories[0]
            if verbose:
                print(f"📁 Extracted folder: {extracted_folder}")
            return extracted_folder
        if len(directories) > 1:
            if verbose:
                print("📁 Multiple directories extracted:")
                for d in directories:
                    print(f"   - {d}")
            return directories[0]  # alphabetically first
        # No directories: the archive held only top-level files.
        if verbose:
            print(f"📄 Files extracted directly to: {destination_path}")
        return destination_path

    except tarfile.TarError as e:
        print(f"❌ Error extracting tar file: {e}")
        return None
    except Exception as e:
        print(f"❌ Unexpected error during extraction: {e}")
        return None

# Remind the reader how to call extract_tar_file (one line per print, joined).
print("\n".join([
    "🔧 Tar extraction utility function defined!",
    "📋 Usage example:",
    "   extracted_path = extract_tar_file('/path/to/dataset.tar', '/path/to/extract/to')",
    "   if extracted_path:",
    "       print(f'Dataset extracted to: {extracted_path}')",
]))


In [None]:
# Optional: extract a manually downloaded MSD archive with extract_tar_file
# (defined in the previous cell). Set `downloaded_tar` to the archive's path and
# re-run this cell. While it is None, the cell only prints instructions.
downloaded_tar = None  # e.g. "/path/to/your/downloaded/Task02_Heart.tar"
extract_location = Path.cwd().parent / "datasets" / "manual_extraction"

if downloaded_tar is not None:
    extracted_folder = extract_tar_file(downloaded_tar, extract_location)
    if extracted_folder:
        print("\n🎯 Next step: Convert to nnU-Net format")
        print(f"Run: !nnUNetv2_convert_MSD_dataset -i {extracted_folder}")
    else:
        print("❌ Extraction failed. Check the file path and try again.")
else:
    print("💡 To use this function:")
    print("1. Download a dataset tar file manually")
    print("2. Set `downloaded_tar` above to the path of the downloaded file")
    print("3. Optionally change `extract_location` (default: ../datasets/manual_extraction)")
    print("4. Run this cell to extract the dataset")


In [None]:
# Training is not started automatically. To train a 2D model on fold 0
# (a quick test), uncomment the shell command below and run this cell:

# !nnUNetv2_train {converted_id} 2d 0

print(
    "To start training, uncomment the line above and run this cell.",
    f"Or copy this command to a new cell: !nnUNetv2_train {converted_id} 2d 0",
    sep="\n",
)
