# 🧠 Tumor Segmentation Demo
## BME 271D Final Project - Ege, Max, Sasha

**How to use:**
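1. Run the cells in order, top to bottom. (Runtime > Run All also works, but it runs *both* Option A and Option B below, so Option B's sample image will replace anything you upload.)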
2. Choose to use our sample images OR upload your own
3. See segmentation results!

In [None]:
# ========== SETUP ==========
# Install packages and download code
!pip install -q numpy matplotlib scipy scikit-image pandas

# Download our code from GitHub
!wget -q https://raw.githubusercontent.com/egeozemek/tumor-segmentation/main/tumor_segmentation.py
!wget -q https://raw.githubusercontent.com/egeozemek/tumor-segmentation/main/generate_realistic_tumors.py

# Download sample tumor images
!mkdir -p data/images data/masks
!wget -q -P data/images/ https://raw.githubusercontent.com/egeozemek/tumor-segmentation/main/data/images/tumor_001.png
!wget -q -P data/images/ https://raw.githubusercontent.com/egeozemek/tumor-segmentation/main/data/images/tumor_002.png
!wget -q -P data/images/ https://raw.githubusercontent.com/egeozemek/tumor-segmentation/main/data/images/tumor_003.png
!wget -q -P data/masks/ https://raw.githubusercontent.com/egeozemek/tumor-segmentation/main/data/masks/tumor_001.png
!wget -q -P data/masks/ https://raw.githubusercontent.com/egeozemek/tumor-segmentation/main/data/masks/tumor_002.png
!wget -q -P data/masks/ https://raw.githubusercontent.com/egeozemek/tumor-segmentation/main/data/masks/tumor_003.png

import tumor_segmentation as ts
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image as PILImage
from google.colab import files

# Helper function to load ANY image as grayscale (handles RGB, RGBA, etc.)
def load_image_safe(filepath):
    """Load any image file as a grayscale float64 array scaled to [0, 1].

    Pillow's ``convert('L')`` collapses color/alpha images to a single 8-bit
    luminance band first, so the returned array is 2-D.

    Args:
        filepath: Path accepted by ``PIL.Image.open``.

    Returns:
        np.ndarray of dtype float64 with values in [0, 1].
    """
    # The context manager closes the file handle immediately instead of
    # leaving it open until the Image object is garbage-collected.
    with PILImage.open(filepath) as img:
        gray = img.convert('L')  # new, fully loaded image; safe after close
    return np.asarray(gray, dtype=np.float64) / 255.0

def load_mask_safe(filepath, threshold=127):
    """Load any image file as a binary (0/1) uint8 mask.

    The image is converted to 8-bit grayscale and every pixel strictly
    greater than ``threshold`` becomes 1; all other pixels become 0.

    Args:
        filepath: Path accepted by ``PIL.Image.open``.
        threshold: Gray level (0-255) above which a pixel is foreground.
            Defaults to 127, the midpoint of the 8-bit range.

    Returns:
        np.ndarray of dtype uint8 containing only 0 and 1.
    """
    # Close the file handle right away rather than waiting for GC.
    with PILImage.open(filepath) as img:
        gray = np.asarray(img.convert('L'))
    return (gray > threshold).astype(np.uint8)

# Initialize variables: nothing is loaded until Option A or Option B runs.
# The analysis cells check these for None before doing any work.
image = None
mask = None

print('✅ Setup complete!')

---
## 📁 Choose Your Image Source

**Run ONE of the following cells:**
- **Option A:** Upload your own image
- **Option B:** Use our sample images

In [None]:
# ========== OPTION A: UPLOAD YOUR OWN IMAGE ==========
# Run this cell to upload your own tumor image


def count_gray_levels(filepath):
    """Return the number of distinct 8-bit grayscale values in an image file.

    Binary masks have only a handful of gray levels while real images have
    many; this is how the two uploaded files are told apart below.
    """
    with PILImage.open(filepath) as img:  # close the file handle promptly
        return len(np.unique(np.asarray(img.convert('L'))))


print('Upload your tumor image (and optionally a mask):')
print('  • 1 file = tumor image only (no accuracy metrics)')
print('  • 2 files = tumor image + mask (get Dice scores!)\n')

uploaded = files.upload()

if uploaded:
    uploaded_files = list(uploaded.keys())

    # A new image invalidates segmentation results from an earlier image;
    # drop them so the visualization cell cannot overlay stale masks.
    if 'results' in globals():
        del results

    if len(uploaded_files) == 1:
        # Only image uploaded
        image = load_image_safe(uploaded_files[0])
        mask = None
        print(f'\n✅ Loaded: {uploaded_files[0]}')
        print('   (No mask - will show visual results only)')

    else:
        if len(uploaded_files) > 2:
            print(f'⚠️ {len(uploaded_files)} files uploaded - only the first two are used.')
        first_file, second_file = uploaded_files[:2]

        # A file with <= 10 gray levels looks like a mask. The files are
        # swapped only when the FIRST looks like the mask and the second does
        # not; otherwise assume upload order is image, then mask.
        levels_first = count_gray_levels(first_file)
        levels_second = count_gray_levels(second_file)
        if levels_first <= 10 and levels_second > 10:
            image_file, mask_file = second_file, first_file
        else:
            image_file, mask_file = first_file, second_file

        image = load_image_safe(image_file)
        mask = load_mask_safe(mask_file)
        print(f'\n✅ IMAGE: {image_file}')
        print(f'✅ MASK: {mask_file}')

        if mask.shape != image.shape:
            # Overlays and Dice scores compare pixel by pixel, so a mask of a
            # different size cannot be used.
            print(f'⚠️ Mask size {mask.shape} does not match image size {image.shape} - ignoring the mask.')
            print('   (No mask - will show visual results only)')
            mask = None
        else:
            print('   (Will calculate accuracy metrics!)')

    # Show the loaded image
    if mask is not None:
        fig, axes = plt.subplots(1, 2, figsize=(10, 4))
        axes[0].imshow(image, cmap='gray')
        axes[0].set_title('Your Image')
        axes[0].axis('off')
        axes[1].imshow(image, cmap='gray')
        axes[1].imshow(mask, cmap='Reds', alpha=0.5)
        axes[1].set_title('With Mask Overlay')
        axes[1].axis('off')
        plt.show()
    else:
        plt.figure(figsize=(6, 6))
        plt.imshow(image, cmap='gray')
        plt.title('Your Uploaded Image')
        plt.axis('off')
        plt.show()

else:
    print('❌ No file uploaded. Run Option B instead to use sample images.')

In [None]:
# ========== OPTION B: USE SAMPLE IMAGES ==========
# Run this cell to use our pre-made tumor images

tumor_number = 1  # Change to 1, 2, or 3

# The setup cell downloads only tumor_001..tumor_003; fail with a clear
# message instead of a FileNotFoundError from Pillow.
if tumor_number not in (1, 2, 3):
    raise ValueError(f'tumor_number must be 1, 2, or 3 (got {tumor_number!r})')

tumor_file = f'tumor_{tumor_number:03d}.png'
print(f'Loading {tumor_file}...')

image = load_image_safe(f'data/images/{tumor_file}')
mask = load_mask_safe(f'data/masks/{tumor_file}')

# Segmentation results computed for a previously loaded image no longer
# apply; drop them so the visualization cell cannot overlay stale masks.
if 'results' in globals():
    del results

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(image, cmap='gray')
axes[0].set_title(f'Tumor Image {tumor_number}')
axes[0].axis('off')

axes[1].imshow(image, cmap='gray')
axes[1].imshow(mask, cmap='Reds', alpha=0.5)
axes[1].set_title('Ground Truth Mask')
axes[1].axis('off')
plt.show()

print(f'\n✅ Loaded tumor {tumor_number} with ground truth mask')

---
## 📊 Analysis

Run the cells below to analyze your image.

In [None]:
# ========== FFT ANALYSIS ==========
# Visualize the loaded image's frequency spectrum using the tumor_segmentation
# helpers. The second return value (kept as `mag`, presumably the magnitude
# spectrum) is not used in this cell.
if image is not None:
    F_shift, mag = ts.compute_fft_spectrum(image)
    fig = ts.visualize_frequency_spectrum(image, F_shift)
    plt.show()
else:
    print('⚠️ No image loaded! Run Option A or B above first.')

In [None]:
# ========== SEGMENTATION ==========
# Run the segmentation methods on the loaded image and store their masks in
# `results` as {method_name: {'mask': ..., ...}}. With a ground-truth mask the
# tumor_segmentation experiment runner also reports a Dice score per method.
if image is not None:
    params = {'hp_radius': 25, 'bp_r1': 10, 'bp_r2': 40, 'canny_sigma': 1.0, 'gaussian_sigma': 1.0}

    if mask is not None:
        # We have ground truth - calculate metrics
        results = ts.run_single_image_experiment(image, mask, params, verbose=True)

        print('\n=== RESULTS ===')
        for method, data in results.items():
            print(f"{method}: Dice = {data['metrics']['dice']:.3f}")
    else:
        # No ground truth - just show segmentations
        print('Running segmentation methods...\n')

        from skimage.filters import threshold_otsu
        from scipy import ndimage

        def otsu_dark_mask(img):
            """Boolean mask of pixels strictly below img's own Otsu threshold."""
            return img < threshold_otsu(img)

        # Inputs for the thresholded methods: raw, Gaussian-smoothed, and
        # FFT high-pass / band-pass filtered versions of the image.
        smoothed = ndimage.gaussian_filter(image, sigma=params['gaussian_sigma'])
        hp_img, _, _ = ts.filter_pipeline(image, 'hp', cutoff_radius=params['hp_radius'])
        bp_img, _, _ = ts.filter_pipeline(image, 'bp', r1=params['bp_r1'], r2=params['bp_r2'])

        results = {
            'Baseline_Raw_Otsu': {'mask': otsu_dark_mask(image)},
            'Baseline_Smooth_Otsu': {'mask': otsu_dark_mask(smoothed)},
            'FFT_HighPass': {'mask': otsu_dark_mask(hp_img)},
            'FFT_BandPass': {'mask': otsu_dark_mask(bp_img)},
            'Canny_Edges': {'mask': ts.canny_segmentation(image, sigma=params['canny_sigma'])},
        }

        print('✅ Segmentation complete!')
else:
    print('⚠️ No image loaded! Run Option A or B above first.')

In [None]:
# ========== VISUALIZATION ==========
# Overlay each method's mask on the image. When a ground-truth mask is loaded,
# ts.plot_segmentation_comparison draws the comparison instead.
current_results = globals().get('results')
if image is None or not current_results:
    print('⚠️ Run the cells above first!')
else:
    masks_dict = {name: data['mask'] for name, data in current_results.items()}

    if any(np.shape(seg) != np.shape(image) for seg in masks_dict.values()):
        # These results belong to a different image (a new one was loaded
        # without re-running SEGMENTATION); overlaying them would mislead.
        print('⚠️ Segmentation results do not match the current image. Re-run the SEGMENTATION cell first!')
    elif mask is not None:
        # Show comparison with ground truth
        fig = ts.plot_segmentation_comparison(image, masks_dict, mask)
        plt.show()
    else:
        # One panel for the original image plus one overlay per method, three
        # per row (6 panels gives a 2x3 grid at 15x10 inches).
        n_panels = len(masks_dict) + 1
        ncols = 3
        nrows = -(-n_panels // ncols)  # ceiling division
        fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows))
        axes = np.ravel(axes)

        axes[0].imshow(image, cmap='gray')
        axes[0].set_title('Original Image', fontweight='bold')
        axes[0].axis('off')

        for idx, (method_name, seg_mask) in enumerate(masks_dict.items(), 1):
            axes[idx].imshow(image, cmap='gray')
            axes[idx].imshow(seg_mask, cmap='Reds', alpha=0.6)
            axes[idx].set_title(method_name, fontweight='bold')
            axes[idx].axis('off')

        # Hide any leftover empty grid cells.
        for ax in axes[n_panels:]:
            ax.axis('off')

        plt.suptitle('Segmentation Method Comparison', fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.show()

        print('\n💡 Without ground truth, we cannot calculate accuracy metrics.')
        print('   But you can visually compare which method works best!')