# Polyner 2SOD Symmetric Geometry Testing Pipeline

This notebook tests Polyner metal-artifact reduction on the 2SOD symmetric-geometry dataset, in which the source-to-detector distance is twice the source-to-object distance (SDD = 2 × SOD).

**Dataset Contents:**
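- `slice42` — original slice (216 × 216, `slice42_216_216.bin`)
- `sino42` — sinogram (360 view angles × 400 detector elements, `sino42_400_360.bin`)
- `rec42` — reconstructed image for verification (216 × 216, `rec42_216_216.bin`)
- MATLAB code for fan-beam projection
- Config file with geometry parameters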

## 1. Setup and Installation

In [None]:
# Install required packages
!pip install torch torchvision torchaudio
!pip install ninja git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
!pip install SimpleITK tqdm numpy scipy scikit-image commentjson

# Clone Polyner repository
!git clone https://github.com/your-repo/PolynerCode.git
%cd PolynerCode/Polyner

# Switch to 2SOD-geometry branch
!git checkout 2SOD-geometry

## 2. Directory Setup and Data Upload

In [None]:
import json
import os
import sys
import zipfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# Working directories for inputs, trained models and reconstructions.
for _dir_name in ('input_2SOD', 'output_2SOD', 'model_2SOD'):
    os.makedirs(_dir_name, exist_ok=True)

print("✅ Directories created")
print("📁 Please upload the 2SOD dataset files to the input_2SOD folder:")
for _expected in (
    "slice42_216_216.bin",
    "sino42_400_360.bin",
    "rec42_216_216.bin",
    "Any MATLAB files",
):
    print(f"   - {_expected}")

## 3. Data Conversion and Preparation

In [None]:
def load_binary_data(file_path, shape, dtype=np.float32):
    """Read a headerless raw binary file into a NumPy array of the given shape.

    Args:
        file_path: Path to the raw ``.bin`` file.
        shape: Target shape; its product must equal the number of ``dtype``
            elements stored in the file.
        dtype: Element type stored in the file (default ``np.float32``).

    Returns:
        ``np.ndarray`` with ``shape`` and ``dtype``, filled in C (row-major) order.

    Raises:
        ValueError: if the file does not contain exactly ``prod(shape)``
            elements of ``dtype``.
    """
    data = np.fromfile(file_path, dtype=dtype)
    expected = int(np.prod(shape))
    if data.size != expected:
        # reshape would fail too, but without naming the file or hinting that
        # the shape/dtype is wrong or the file is truncated.
        raise ValueError(
            f"{file_path}: found {data.size} {np.dtype(dtype).name} values, "
            f"expected {expected} for shape {shape}"
        )
    # NOTE(review): MATLAB's fwrite stores arrays column-major; if the .bin files
    # were written that way, this C-order reshape yields the transpose of the
    # MATLAB view — confirm against the accompanying MATLAB code.
    return data.reshape(shape)

def save_as_nifti(data, output_path):
    """Write a NumPy array to ``output_path`` with SimpleITK and return the image.

    2-D arrays are converted to float32 before writing; arrays of any other
    rank are written with their existing dtype.
    """
    array = data.astype(np.float32) if data.ndim == 2 else data
    image = sitk.GetImageFromArray(array)
    sitk.WriteImage(image, str(output_path))
    return image

# Load and convert 2SOD dataset
print("Loading 2SOD dataset...")

# Later cells test `'slice42' in locals()` etc. Drop arrays left over from a
# previous run first, so a re-run with a missing .bin file cannot silently
# reuse stale data.
for _name in ('slice42', 'sino42', 'rec42'):
    globals().pop(_name, None)


def _load_and_convert(bin_path, shape, nifti_path, label):
    """Load a raw .bin file, save it as ``nifti_path`` and print its range.

    Returns the loaded array, or ``None`` (after printing a message) when
    ``bin_path`` does not exist.
    """
    if not os.path.exists(bin_path):
        print(f"❌ {os.path.basename(bin_path)} not found")
        return None
    data = load_binary_data(bin_path, shape)
    save_as_nifti(data, nifti_path)
    print(f"✅ {label} loaded: {data.shape}, range: [{np.min(data):.3f}, {np.max(data):.3f}]")
    return data


# Original slice (216x216) -> gt_0.nii
_data = _load_and_convert('input_2SOD/slice42_216_216.bin', (216, 216),
                          'input_2SOD/gt_0.nii', "Original slice")
if _data is not None:
    slice42 = _data

# Sinogram (file name says 400_360) -> ma_sinogram_0.nii, read as (angles, detectors)
_data = _load_and_convert('input_2SOD/sino42_400_360.bin', (360, 400),
                          'input_2SOD/ma_sinogram_0.nii', "Sinogram")
if _data is not None:
    sino42 = _data

# Reconstruction for verification (216x216) -> ma_0.nii
_data = _load_and_convert('input_2SOD/rec42_216_216.bin', (216, 216),
                          'input_2SOD/ma_0.nii', "Reconstruction")
if _data is not None:
    rec42 = _data

## 4. 2SOD Geometry Configuration

In [None]:
# Create 2SOD symmetric geometry configuration.
# Geometry constants are defined once, and the 2SOD invariant (SDD = 2 × SOD)
# is computed rather than typed twice, so editing SOD cannot silently break it.
SOD_MM = 410          # Source-to-object distance (mm)
SDD_MM = 2 * SOD_MM   # Source-to-detector distance: 2 × SOD (mm)
IMG_SIZE = 216        # Reconstruction grid is IMG_SIZE × IMG_SIZE (matches slice42)

config_2SOD = {
    "file": {
        "in_dir": "./input_2SOD",
        "model_dir": "./model_2SOD",
        "out_dir": "./output_2SOD",
        "voxel_size": 1.0,
        "SOD": SOD_MM,
        "SDD": SDD_MM,
        "detector_geometry": "arc",  # Use arc for symmetric geometry
        "geometry_type": "symmetric_fan_beam",
        "h": IMG_SIZE,
        "w": IMG_SIZE
    },
    "train": {
        "gpu": 0,
        "lr": 0.001,
        "epoch": 2000,
        "save_epoch": 500,
        "num_sample_ray": 2,
        "lr_decay_epoch": 1000,
        "lr_decay_coefficient": 0.1,
        "batch_size": 40,
        "lambda": 0.2
    },
    "encoding": {
        "otype": "Grid",
        "type": "Hash",
        "n_levels": 16,
        "n_features_per_level": 8,
        "log2_hashmap_size": 19,
        "base_resolution": 2,
        "per_level_scale": 2,
        "interpolation": "Linear"
    },
    "network": {
        "otype": "FullyFusedMLP",
        "activation": "ReLU",
        "output_activation": "Squareplus",
        "n_neurons": 128,
        "n_hidden_layers": 2
    }
}

# Save configuration
with open('config_2SOD.json', 'w') as f:
    json.dump(config_2SOD, f, indent=4)

print("✅ 2SOD symmetric geometry configuration created")
print(f"   SOD: {config_2SOD['file']['SOD']} mm")
print(f"   SDD: {config_2SOD['file']['SDD']} mm (2 × SOD)")
print(f"   Geometry: {config_2SOD['file']['geometry_type']}")
print(f"   Geometry: {config_2SOD['file']['geometry_type']}")

## 5. Fan Sensor Position Generation

In [None]:
# Generate fan sensor positions for 2SOD geometry
def generate_2SOD_fan_positions(num_detectors=400, SOD=410, half_fov=108, SDD=None):
    """Generate fan sensor positions for the symmetric 2SOD geometry.

    Rays are spaced uniformly in fan angle over ``[-gamma_max, +gamma_max]``,
    where ``gamma_max = arctan(half_fov / SOD)``. Each angle ``gamma`` is then
    mapped to the position ``SDD * tan(gamma)``.

    Args:
        num_detectors: Number of detector elements.
        SOD: Source-to-object distance.
        half_fov: Half-width used to size the fan (default 108, half of the
            216-pixel image; presumably in the same units as ``SOD``).
        SDD: Source-to-detector distance; defaults to ``2 * SOD`` (2SOD geometry).

    Returns:
        float32 array of shape ``(num_detectors,)``. Positions are symmetric
        about 0 and, because of the tan mapping, not uniformly spaced.

    NOTE(review): ``SDD * tan(gamma)`` is the coordinate on a *flat* detector,
    while the config sets ``"detector_geometry": "arc"``. For an arc detector
    the natural coordinate is the angle ``gamma`` (or arc length
    ``SDD * gamma``). Confirm which convention Polyner's ``fanSensorPos``
    expects.
    """
    if SDD is None:
        SDD = 2 * SOD

    # Maximum fan half-angle needed to cover the object half-width.
    gamma_max = np.arctan(half_fov / SOD)

    # Equiangular detector sampling.
    detector_angles = np.linspace(-gamma_max, gamma_max, num_detectors)

    # Project each ray angle onto the detector at distance SDD.
    detector_positions = SDD * np.tan(detector_angles)

    return detector_positions.astype(np.float32)

# Generate and save fan sensor positions. SOD is read from the config so the
# sensor layout cannot drift from the geometry written to config_2SOD.json.
NUM_DETECTORS = 400  # must equal the detector axis of sino42, read as (360, 400)
if 'sino42' in locals():
    assert sino42.shape[1] == NUM_DETECTORS, (
        f"sinogram has {sino42.shape[1]} detector columns, expected {NUM_DETECTORS}"
    )

fan_positions = generate_2SOD_fan_positions(NUM_DETECTORS, config_2SOD['file']['SOD'])
fanSensorPos = fan_positions.reshape(-1, 1)  # column vector: (NUM_DETECTORS, 1)
fanPos_sitk = sitk.GetImageFromArray(fanSensorPos)
sitk.WriteImage(fanPos_sitk, 'input_2SOD/fanSensorPos.nii')

# Positions come from tan() of equally spaced angles, so spacing is not
# uniform; (max - min) / (n - 1) is the mean spacing.
mean_spacing = (np.max(fan_positions) - np.min(fan_positions)) / max(len(fan_positions) - 1, 1)

print("✅ Fan sensor positions generated:")
print(f"   Shape: {fanSensorPos.shape}")
print(f"   Range: {np.min(fan_positions):.1f} to {np.max(fan_positions):.1f} mm")
print(f"   Mean detector spacing: {mean_spacing:.3f} mm")

## 6. Create Metal Mask (Placeholder)

In [None]:
# Create a simple metal mask for testing.
# This should be replaced with actual metal segmentation.
METAL_PERCENTILE = 95  # pixels of slice42 above this percentile are labelled metal

# Drop a mask from a previous run so later `'metal_mask' in locals()` checks
# only ever see a mask built from the currently loaded slice.
globals().pop('metal_mask', None)

if 'slice42' in locals():
    # Placeholder threshold: it labels roughly the top (100 - METAL_PERCENTILE)%
    # of pixels as metal whether or not the slice actually contains metal.
    threshold = np.percentile(slice42, METAL_PERCENTILE)
    metal_mask = (slice42 > threshold).astype(np.float32)

    # Save mask
    save_as_nifti(metal_mask, 'input_2SOD/mask_0.nii')

    n_metal = int(np.count_nonzero(metal_mask))
    print("✅ Metal mask created:")
    print(f"   Threshold: {threshold:.3f}")
    print(f"   Metal pixels: {n_metal}")
    print(f"   Metal percentage: {100 * n_metal / metal_mask.size:.1f}%")
else:
    print("❌ Cannot create mask - slice42 not loaded")

## 7. Copy X-ray Spectrum

In [None]:
# Copy the first spectrum file found in the repository's input/ folder into
# input_2SOD/ (needed for polychromatic modelling).
import shutil

spectrum_files = ['GE14Spectrum120KVP.mat', 'spectrum_UNC.mat']
spectrum_copied = False

for candidate in spectrum_files:
    src = f'input/{candidate}'
    if not os.path.exists(src):
        continue
    shutil.copy2(src, f'input_2SOD/{candidate}')
    print(f"✅ Copied {candidate}")
    spectrum_copied = True
    break

if not spectrum_copied:
    print("❌ No spectrum file found - you may need to upload one")

## 8. Data Visualization

In [None]:
# Visualize the 2SOD dataset: one panel per loaded input in a 2x2 grid.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
ax_slice, ax_sino, ax_rec, ax_mask = axes.flat


def _show_image(ax, image, title, cmap='gray'):
    """Draw ``image`` on ``ax`` with ``title`` and the axes hidden."""
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')


if 'slice42' in locals():
    _show_image(ax_slice, slice42, 'Original Slice42')

if 'sino42' in locals():
    # The sinogram panel keeps its axes so the detector/angle labels show.
    ax_sino.imshow(sino42, cmap='gray', aspect='auto')
    ax_sino.set_title('Sinogram (360×400)')
    ax_sino.set_xlabel('Detector')
    ax_sino.set_ylabel('Angle')

if 'rec42' in locals():
    _show_image(ax_rec, rec42, 'Reconstruction (Verification)')

if 'metal_mask' in locals():
    _show_image(ax_mask, metal_mask, 'Metal Mask', cmap='hot')

plt.tight_layout()
plt.show()

# Fan sensor position versus detector index.
if 'fan_positions' in locals():
    fan_fig, ax_fan = plt.subplots(figsize=(10, 3))
    ax_fan.plot(fan_positions, 'b.-', markersize=2)
    ax_fan.set_title('2SOD Fan Sensor Positions')
    ax_fan.set_xlabel('Detector Index')
    ax_fan.set_ylabel('Position (mm)')
    ax_fan.grid(True, alpha=0.3)
    plt.show()

## 9. Training Pipeline

In [None]:
# Train Polyner on 2SOD dataset
print("🚀 Starting Polyner training on 2SOD symmetric geometry dataset...")

# Check if all required files exist
required_files = [
    'input_2SOD/gt_0.nii',
    'input_2SOD/ma_sinogram_0.nii',
    'input_2SOD/mask_0.nii',
    'input_2SOD/fanSensorPos.nii',
    'config_2SOD.json'
]

all_files_exist = True
for file_path in required_files:
    if os.path.exists(file_path):
        print(f"✅ {file_path}")
    else:
        print(f"❌ {file_path} - MISSING")
        all_files_exist = False

if all_files_exist:
    print("\n🎯 All files ready - starting training...")
    # Run training
    !python main.py --config config_2SOD.json --img_id 0
else:
    print("\n⚠️  Missing files - please ensure all data is uploaded")

## 10. Results Analysis and Verification

In [None]:
# Load and compare results
def _report_metrics(reference, result, label):
    """Print PSNR and SSIM of ``result`` against ``reference``.

    The data range is taken from ``reference`` (max - min). When that range
    is zero (constant reference), the metrics are undefined and are skipped
    with a message instead.
    """
    data_range = reference.max() - reference.min()
    if data_range == 0:
        print(f"⚠️  Polyner vs {label}: reference image is constant - metrics skipped")
        return
    psnr = peak_signal_noise_ratio(reference, result, data_range=data_range)
    ssim = structural_similarity(reference, result, data_range=data_range)
    print(f"📊 Polyner vs {label}:")
    print(f"   PSNR: {psnr:.2f} dB")
    print(f"   SSIM: {ssim:.4f}")


if os.path.exists('output_2SOD/polyner_0.nii'):
    # Load Polyner result. np.squeeze drops singleton axes in case the NIfTI
    # was written as a one-slice volume, so it lines up with the 2-D references.
    polyner_result = np.squeeze(
        sitk.GetArrayFromImage(sitk.ReadImage('output_2SOD/polyner_0.nii'))
    )

    # Compare with verification reconstruction
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Original slice
    if 'slice42' in locals():
        axes[0,0].imshow(slice42, cmap='gray')
        axes[0,0].set_title('Original Slice')
        axes[0,0].axis('off')

    # Verification reconstruction
    if 'rec42' in locals():
        axes[0,1].imshow(rec42, cmap='gray')
        axes[0,1].set_title('Verification Reconstruction')
        axes[0,1].axis('off')

    # Polyner result
    axes[0,2].imshow(polyner_result, cmap='gray')
    axes[0,2].set_title('Polyner Result')
    axes[0,2].axis('off')

    # Difference maps
    if 'slice42' in locals():
        diff_orig = np.abs(polyner_result - slice42)
        axes[1,0].imshow(diff_orig, cmap='hot')
        axes[1,0].set_title('|Polyner - Original|')
        axes[1,0].axis('off')

    if 'rec42' in locals():
        diff_rec = np.abs(polyner_result - rec42)
        axes[1,1].imshow(diff_rec, cmap='hot')
        axes[1,1].set_title('|Polyner - Verification|')
        axes[1,1].axis('off')

    # Profile comparison
    center_row = polyner_result.shape[0] // 2
    axes[1,2].plot(polyner_result[center_row, :], 'b-', label='Polyner', linewidth=2)
    if 'slice42' in locals():
        axes[1,2].plot(slice42[center_row, :], 'g--', label='Original', alpha=0.7)
    if 'rec42' in locals():
        axes[1,2].plot(rec42[center_row, :], 'r:', label='Verification', alpha=0.7)
    axes[1,2].set_title('Center Row Profile')
    axes[1,2].legend()
    axes[1,2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Calculate metrics. Each comparison runs on its own; the verification
    # metrics no longer depend on the original-slice branch having executed.
    if 'slice42' in locals():
        _report_metrics(slice42, polyner_result, "Original")
    if 'rec42' in locals():
        _report_metrics(rec42, polyner_result, "Verification")

else:
    print("❌ No Polyner results found - training may not have completed")

## 11. Export Results

In [None]:
# Package results for download. The stdlib zipfile module removes the need
# for a `zip` binary. Mode 'w' rebuilds the archive from scratch, so files
# deleted since a previous run do not linger in it (`zip -r` only updates an
# existing archive).
ARCHIVE_PATH = '2SOD_results.zip'
EXPORT_PATHS = ['output_2SOD', 'model_2SOD', 'config_2SOD.json']

with zipfile.ZipFile(ARCHIVE_PATH, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
    for export_path in EXPORT_PATHS:
        if os.path.isdir(export_path):
            for root, dirs, files in os.walk(export_path):
                dirs.sort()  # deterministic traversal order
                for file_name in sorted(files):
                    archive.write(os.path.join(root, file_name))
        elif os.path.isfile(export_path):
            archive.write(export_path)
        else:
            print(f"⚠️  {export_path} not found - skipped")

print(f"✅ Results packaged in {ARCHIVE_PATH}")
print("📁 Contents:")
print("   - output_2SOD/ (Polyner reconstructions)")
print("   - model_2SOD/ (Trained models)")
print("   - config_2SOD.json (Configuration used)")