# Cross-Validation Training for U-Net MRI Super-Resolution

## Recommended Workflow for 7 Subjects

**Strategy:** 6-fold CV + 1 held-out subject

1. **Hold out 1 subject** (e.g., Subject 0021) for final testing
2. **6-fold CV** on remaining 6 subjects (e.g., 0022-0027)
   - Each fold: Train on 5, test on 1
   - Gives 6 models with independent test performance
3. **Select best fold** based on average test metrics (PSNR/SSIM)
4. **Test best model** on held-out subject (0021)

**Why this approach?**
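- ✅ Unbiased final estimate: the held-out subject is never used for training or model selection
- ✅ Six fold-level test scores give a mean ± SD estimate of performance across subjects
- ✅ The selected model gets a final test on an unseen subject
- ✅ Common practice when only a few subjects are available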

**Before starting:**
1. Prepare data as ZIP files (`Synth_LR_nii.zip` and `HR_nii.zip`)
2. Enable GPU: Runtime → Change runtime type → GPU
3. Upload ZIP files when prompted (Section 2)

**After training:**
- Download results from `/content/cross_validation_results/`
- In the Files panel (left sidebar), right-click the folder → Download

## 1. Setup Environment

In [None]:
# Check GPU
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ GPU not available! Enable: Runtime → Change runtime type → GPU")

In [None]:
# Clone repository
!cd /content && git clone https://github.com/marioknicola/synthsup-speechMRI-recon.git
%cd /content/synthsup-speechMRI-recon
!pip install -q -r requirements.txt
print("✅ Repository cloned and dependencies installed")

## 2. Upload and Extract Data

Upload your data as ZIP files for faster transfer:

In [None]:
# Upload ZIP files
from google.colab import files
import zipfile
import os

print("="*70)
print("UPLOAD DATA")
print("="*70)
print("Please upload 2 ZIP files:")
print("  1. Synth_LR_nii.zip  → Input/LR images")
print("  2. HR_nii.zip        → Target/HR images")
print("\nClick 'Choose Files' below...")
print("="*70)

uploaded = files.upload()
print(f"\n✅ Uploaded {len(uploaded)} file(s)")

# Extract ZIP files
INPUT_DIR = "/content/data/Synth_LR_nii"
TARGET_DIR = "/content/data/HR_nii"
OUTPUT_DIR = "/content/cross_validation_results"  # Local storage only

os.makedirs("/content/data", exist_ok=True)

print("\nExtracting ZIP files...")
for filename in uploaded.keys():
    print(f"  Extracting {filename}...")
    with zipfile.ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall("/content/data")
    os.remove(filename)  # Clean up ZIP file
    print(f"    Done!")

# Verify extraction
print("\n" + "="*70)
print("DATA VERIFICATION")
print("="*70)

if os.path.exists(INPUT_DIR):
    input_files = [f for f in os.listdir(INPUT_DIR) if f.endswith('.nii')]
    print(f"✅ Input: {len(input_files)} files")
    print(f"   Location: {INPUT_DIR}")
    print(f"   Sample: {input_files[:3]}")
else:
    print(f"❌ Input directory not found: {INPUT_DIR}")
    print(f"   Make sure Synth_LR_nii folder is in the ZIP")

if os.path.exists(TARGET_DIR):
    target_files = [f for f in os.listdir(TARGET_DIR) if f.endswith('.nii')]
    print(f"✅ Target: {len(target_files)} files")
    print(f"   Location: {TARGET_DIR}")
    print(f"   Sample: {target_files[:3]}")
else:
    print(f"❌ Target directory not found: {TARGET_DIR}")
    print(f"   Make sure HR_nii folder is in the ZIP")

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"\n✅ Output directory: {OUTPUT_DIR}")
print(f"   (Stored locally in Colab - download results after training)")
print("="*70)
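
The verification above only counts files per directory. A minimal sketch of a stricter check follows: it confirms that every subject has the same number of input and target files. It assumes filenames embed the subject ID as `Subject<ID>`, which matches the `*Subject{ID}*.nii` pattern used for inference later in this notebook; the helper names are hypothetical and not part of the repository.

```python
import re
from collections import Counter

# Assumption: the subject ID appears in each filename as "Subject<digits>".
SUBJECT_RE = re.compile(r"Subject(\d+)")


def count_files_per_subject(filenames):
    """Return a Counter mapping subject ID -> number of .nii files."""
    counts = Counter()
    for name in filenames:
        match = SUBJECT_RE.search(name)
        if name.endswith(".nii") and match:
            counts[match.group(1)] += 1
    return counts


def find_unpaired(input_names, target_names):
    """Return the subject IDs whose input and target file counts differ."""
    inputs = count_files_per_subject(input_names)
    targets = count_files_per_subject(target_names)
    subjects = set(inputs) | set(targets)
    # Counter returns 0 for a missing subject, so a missing side counts as a mismatch.
    return sorted(s for s in subjects if inputs[s] != targets[s])
```

In the notebook this could be called as `find_unpaired(os.listdir(INPUT_DIR), os.listdir(TARGET_DIR))`; an empty list means every subject is paired.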

## 3. Define Subjects and CV Strategy

**For 7 subjects:** 6-fold CV + 1 held-out

In [None]:
# ✅ CONFIGURE YOUR 7 SUBJECTS
ALL_SUBJECTS = ['0021', '0022', '0023', '0024', '0025', '0026', '0027']

# ✅ CHOOSE HELD-OUT SUBJECT (will NOT be used in CV)
HELD_OUT_SUBJECT = '0021'  # Final test subject, never seen during CV

# Subjects for 6-fold CV (excluding held-out)
CV_SUBJECTS = [s for s in ALL_SUBJECTS if s != HELD_OUT_SUBJECT]

print("="*70)
print("EXPERIMENTAL DESIGN: 6-FOLD CV + HELD-OUT SUBJECT")
print("="*70)
print(f"\n📦 Held-out subject: {HELD_OUT_SUBJECT}")
print(f"   → Will be used for FINAL testing only")
print(f"   → Never seen by any model during training")
print(f"\n🔄 CV subjects: {CV_SUBJECTS}")
print(f"   → Used for 6-fold cross-validation")
print(f"   → Each model trains on 5, tests on 1")
print(f"\n📊 Workflow:")
print(f"   1. Train 6 models (6-fold CV on {len(CV_SUBJECTS)} subjects)")
print(f"   2. Select best model based on CV test performance")
print(f"   3. Test best model on held-out subject ({HELD_OUT_SUBJECT})")
print("="*70)

In [None]:
# Generate 6-fold CV on CV subjects (excludes held-out)
cv_folds = []
for i, test_subject in enumerate(CV_SUBJECTS, 1):
    train_subjects = [s for s in CV_SUBJECTS if s != test_subject]
    cv_folds.append({
        'fold_name': f'fold{i}',
        'train': train_subjects,
        'test': [test_subject]
    })

# Display folds
print("\n" + "="*70)
print("6-FOLD CROSS-VALIDATION SPLITS")
print("="*70)
for fold in cv_folds:
    print(f"\n{fold['fold_name']}:")
    print(f"  Train (n={len(fold['train'])}): {fold['train']}")
    print(f"  Test  (n={len(fold['test'])}): {fold['test']}")

print(f"\n" + "="*70)
print(f"⚠️  Subject {HELD_OUT_SUBJECT} is NOT in any fold")
print(f"   It will be tested AFTER selecting the best fold")
print("="*70)
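
Because the whole design depends on keeping subjects separated, it can help to check the splits before training. The sketch below rebuilds the same leave-one-subject-out folds and raises an error if a subject leaks between train and test or if the held-out subject appears anywhere. `validate_folds` is a hypothetical helper, not a function from the repository.

```python
def validate_folds(folds, held_out):
    """Raise ValueError if the splits leak subjects; return True otherwise."""
    tested = []
    for fold in folds:
        train, test = set(fold["train"]), set(fold["test"])
        if train & test:
            raise ValueError(f"{fold['fold_name']}: {sorted(train & test)} in both train and test")
        if held_out in train | test:
            raise ValueError(f"{fold['fold_name']}: held-out subject {held_out} appears in the fold")
        tested.extend(fold["test"])
    if len(tested) != len(set(tested)):
        raise ValueError("A subject is the test subject of more than one fold")
    return True


# Same construction as the cell above: each CV subject is the test subject once.
subjects = ["0022", "0023", "0024", "0025", "0026", "0027"]
folds = [
    {"fold_name": f"fold{i}", "train": [s for s in subjects if s != t], "test": [t]}
    for i, t in enumerate(subjects, 1)
]
validate_folds(folds, held_out="0021")
```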

## 4. Training Configuration

In [None]:
# ✅ CONFIGURE TRAINING PARAMETERS
TRAINING_CONFIG = {
    'epochs': 200,
    'batch_size': 4,
    'lr': 1e-5,
    'base_filters': 32,
    'loss_alpha': 0.7  # 70% MSE + 30% SSIM
}

print("Training Configuration:")
print("="*40)
for key, value in TRAINING_CONFIG.items():
    print(f"  {key:15s}: {value}")
print("="*40)
print("\n💾 Model Saving: Only best model per fold (saves storage)")
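
The comment on `loss_alpha` describes a 70% MSE + 30% SSIM mix. The sketch below shows one common way such a weighting is written, assuming the SSIM term enters as a dissimilarity (`1 - SSIM`) so that lower is better for both parts. This is an assumption for illustration; the exact formula in `train_cross_validation.py` may differ.

```python
def combined_loss(mse, ssim, alpha=0.7):
    """Weighted sum of MSE and an SSIM dissimilarity term (1 - SSIM).

    Assumed form: alpha * MSE + (1 - alpha) * (1 - SSIM).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return alpha * mse + (1.0 - alpha) * (1.0 - ssim)
```

Under this form, a perfect reconstruction (MSE = 0, SSIM = 1) gives a loss of 0, and raising `alpha` shifts the emphasis from structural similarity toward pixel-wise error.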

## 5. Train All Folds

This will train all 6 folds sequentially:

In [None]:
# Train all folds
import time

start_time = time.time()

for fold in cv_folds:
    print("\n" + "="*80)
    print(f"TRAINING: {fold['fold_name'].upper()}")
    print("="*80)
    
    # Build command
    cmd = "python train_cross_validation.py "
    cmd += f"--train-subjects {' '.join(fold['train'])} "
    cmd += f"--test-subjects {' '.join(fold['test'])} "
    cmd += f"--fold-name {fold['fold_name']} "
    cmd += f"--input-dir '{INPUT_DIR}' "
    cmd += f"--target-dir '{TARGET_DIR}' "
    cmd += f"--output-dir '{OUTPUT_DIR}' "
    cmd += f"--epochs {TRAINING_CONFIG['epochs']} "
    cmd += f"--batch-size {TRAINING_CONFIG['batch_size']} "
    cmd += f"--lr {TRAINING_CONFIG['lr']} "
    cmd += f"--base-filters {TRAINING_CONFIG['base_filters']} "
    cmd += f"--loss-alpha {TRAINING_CONFIG['loss_alpha']}"
    
    # Execute
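    !{cmd}
    
    # `!` stores the command's exit status in _exit_code; report failures explicitly
    if _exit_code == 0:
        print(f"\n✅ Completed {fold['fold_name']}")
    else:
        print(f"\n❌ {fold['fold_name']} failed (exit code {_exit_code})")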

total_time = time.time() - start_time
print("\n" + "="*80)
print("ALL FOLDS COMPLETED")
print("="*80)
print(f"Total training time: {total_time/3600:.2f} hours")
print(f"Average per fold: {total_time/len(cv_folds)/3600:.2f} hours")

## 6. Evaluate and Select Best Fold

After all folds are trained, evaluate which performed best:

In [None]:
# Collect test performance from all folds
import json
import numpy as np

fold_results = []

print("="*80)
print("CROSS-VALIDATION RESULTS")
print("="*80)

for fold in cv_folds:
    history_file = os.path.join(OUTPUT_DIR, fold['fold_name'], 'training_history.json')
    
    if os.path.exists(history_file):
        with open(history_file, 'r') as f:
            history = json.load(f)
        
        # Get final epoch performance (could also use best)
        final_train_loss = history['train_loss'][-1]
        final_train_ssim = history['train_ssim'][-1]
        
        # For test performance, run inference (shown in next cell)
        # For now, use training metrics as proxy
        fold_results.append({
            'fold': fold['fold_name'],
            'test_subject': fold['test'][0],
            'train_loss': final_train_loss,
            'train_ssim': final_train_ssim
        })
        
        print(f"\n{fold['fold_name']} (test={fold['test'][0]}):")
        print(f"  Final train loss: {final_train_loss:.6f}")
        print(f"  Final train SSIM: {final_train_ssim:.4f}")
    else:
        print(f"\n{fold['fold_name']}: ❌ No results found")

# Find best fold (lowest loss, or highest SSIM)
if fold_results:
    best_fold = min(fold_results, key=lambda x: x['train_loss'])
    
    print("\n" + "="*80)
    print("BEST FOLD (based on training metrics)")
    print("="*80)
    print(f"Fold: {best_fold['fold']}")
    print(f"Test subject: {best_fold['test_subject']}")
    print(f"Train loss: {best_fold['train_loss']:.6f}")
    print(f"Train SSIM: {best_fold['train_ssim']:.4f}")
    print("\n⚠️  For proper selection, run inference on each fold's test set")
    print("   and select based on test PSNR/SSIM (see next cell)")
    print("="*80)
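
The cell above reads the final epoch and notes that the best epoch could be used instead. A minimal sketch of that alternative follows, assuming `training_history.json` stores one value per epoch under the `train_loss` and `train_ssim` keys read above. `best_epoch` is a hypothetical helper, and the history values are placeholders rather than real results.

```python
def best_epoch(history, key="train_loss", mode="min"):
    """Return (1-based epoch, value) of the best entry in history[key]."""
    values = history[key]
    if not values:
        raise ValueError(f"history[{key!r}] is empty")
    pick = min if mode == "min" else max
    index = pick(range(len(values)), key=values.__getitem__)
    return index + 1, values[index]


# Placeholder history for illustration only.
history = {"train_loss": [0.5, 0.2, 0.3], "train_ssim": [0.6, 0.8, 0.9]}
print(best_epoch(history))                       # lowest loss
print(best_epoch(history, "train_ssim", "max"))  # highest SSIM
```

These are still training metrics, so they remain a proxy; fold selection should rest on test-set inference.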

## 7. Run Inference on Each Fold's Test Set

**Proper model selection requires test set performance:**

In [None]:
# Run inference on each fold's test subject
print("="*80)
print("RUNNING INFERENCE ON TEST SETS")
print("="*80)

for fold in cv_folds:
    fold_name = fold['fold_name']
    test_subject = fold['test'][0]
    
    print(f"\n{fold_name}: Testing on Subject {test_subject}...")
    
    best_model = os.path.join(OUTPUT_DIR, fold_name, f"{fold_name}_best.pth")
    inference_output = os.path.join(OUTPUT_DIR, fold_name, 'test_inference')
    
    if os.path.exists(best_model):
        !python inference_unet.py \
            --checkpoint "{best_model}" \
            --input-dir "{INPUT_DIR}" \
            --target-dir "{TARGET_DIR}" \
            --output-dir "{inference_output}" \
            --file-pattern "*Subject{test_subject}*.nii" \
            --compute-metrics
        
        print(f"✅ {fold_name} inference complete")
    else:
        print(f"❌ Model not found: {best_model}")

print("\n" + "="*80)
print("Check inference output for test set metrics")
print("="*80)

## 8. Test Best Model on Held-Out Subject

After identifying the best fold, test it on the held-out subject:

In [None]:
# ✅ SELECT BEST FOLD (based on test performance from above)
BEST_FOLD_NAME = 'fold1'  # UPDATE THIS after checking test metrics

print("="*80)
print(f"FINAL TEST: {BEST_FOLD_NAME} on Held-Out Subject {HELD_OUT_SUBJECT}")
print("="*80)

best_model = os.path.join(OUTPUT_DIR, BEST_FOLD_NAME, f"{BEST_FOLD_NAME}_best.pth")
final_test_output = os.path.join(OUTPUT_DIR, 'final_test_heldout')

if os.path.exists(best_model):
    !python inference_unet.py \
        --checkpoint "{best_model}" \
        --input-dir "{INPUT_DIR}" \
        --target-dir "{TARGET_DIR}" \
        --output-dir "{final_test_output}" \
        --file-pattern "*Subject{HELD_OUT_SUBJECT}*.nii" \
        --compute-metrics \
        --visualize
    
    print("\n" + "="*80)
    print("✅ FINAL TEST COMPLETE")
    print("="*80)
    print(f"Best model: {BEST_FOLD_NAME}")
    print(f"Held-out subject: {HELD_OUT_SUBJECT}")
    print(f"Results saved to: {final_test_output}")
    print("="*80)
else:
    print(f"❌ Model not found: {best_model}")

## 9. Summary and Reporting

**For your abstract, report:**

1. **CV Performance (6 folds):**
   - Mean ± SD of test metrics across 6 folds
   - Example: "PSNR = 28.5 ± 1.2 dB (6-fold CV)"

2. **Held-Out Performance:**
   - Best model's performance on Subject 0021
   - Example: "PSNR = 27.8 dB on held-out subject"

3. **Methods:**
   - "6-fold cross-validation on 6 subjects (training: n=5, testing: n=1 per fold)"
   - "Best model selected based on CV performance"
   - "Final evaluation on held-out subject (never seen during training)"
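
The mean ± SD figure in item 1 can be computed with the standard library once the per-fold test metrics have been collected. The sketch below assumes one PSNR value per fold; the numbers are placeholders for illustration, not results, and `summarize` is a hypothetical helper.

```python
import statistics


def summarize(values, unit=""):
    """Format per-fold values as 'mean ± SD' using the sample SD (n - 1 denominator)."""
    mean = statistics.mean(values)
    sd = statistics.stdev(values)
    return f"{mean:.1f} ± {sd:.1f}{unit}"


# Placeholder per-fold test PSNR values (dB), one per fold.
fold_psnr = [27.0, 28.0, 29.0, 30.0, 28.0, 29.0]
print(f"PSNR = {summarize(fold_psnr, ' dB')} (6-fold CV)")
```

With only six folds, the sample SD is the usual choice for reporting; state which one you used in the methods.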

## Tips

### Model Selection:
- **Don't** use training loss to select best fold
- **Do** use test set inference metrics (PSNR/SSIM)
- **Option 1:** Highest average test PSNR
- **Option 2:** Highest average test SSIM
- **Option 3:** Best average rank across metrics
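
Option 3 can be sketched as follows: rank the folds separately for each metric (rank 1 is best), then average the ranks per fold. The sketch assumes higher is better for every metric, as with PSNR and SSIM; the values are placeholders, not results, and `average_ranks` is a hypothetical helper.

```python
def average_ranks(metrics):
    """metrics: {metric_name: {fold_name: value}}, higher is better.

    Returns {fold_name: mean rank across metrics}, where rank 1 is best.
    Ties are broken by fold name order.
    """
    folds = sorted(next(iter(metrics.values())))
    totals = {fold: 0.0 for fold in folds}
    for scores in metrics.values():
        ordered = sorted(folds, key=lambda fold: scores[fold], reverse=True)
        for rank, fold in enumerate(ordered, 1):
            totals[fold] += rank
    return {fold: totals[fold] / len(metrics) for fold in folds}


# Placeholder test metrics for three folds.
metrics = {
    "psnr": {"fold1": 28.0, "fold2": 29.0, "fold3": 27.0},
    "ssim": {"fold1": 0.90, "fold2": 0.91, "fold3": 0.85},
}
ranks = average_ranks(metrics)
best = min(ranks, key=ranks.get)
print(best, ranks)
```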

### Prevent Disconnection:
```javascript
// Run in browser console (F12)
setInterval(() => {
    document.querySelector("colab-connect-button").click()
}, 60000)
```

### Memory Management:
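- Each fold runs as a separate `python` process, so its GPU memory is released when that process exits; call `torch.cuda.empty_cache()` only if you train inside the notebook kernel
- Reduce `batch_size` if you hit out-of-memory (OOM) errors
- Monitor GPU usage with `!nvidia-smi`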

## 📥 Download Results

After training completes, download your results to continue locally:

In [None]:
# Download all results as a ZIP file
import os
import shutil

print("="*70)
print("PREPARING RESULTS FOR DOWNLOAD")
print("="*70)

# Create ZIP of all results (make_archive returns the path of the archive it wrote)
print("\nCreating ZIP archive...")
output_zip = shutil.make_archive("/content/cross_validation_results", 'zip', OUTPUT_DIR)

# Get size
size_mb = os.path.getsize(output_zip) / (1024 * 1024)
print(f"✅ Archive created: {output_zip} ({size_mb:.1f} MB)")

print("\n" + "="*70)
print("DOWNLOAD OPTIONS")
print("="*70)
print("\n1️⃣ Download ZIP file:")
print("   Run the next cell to download the ZIP")
print("\n2️⃣ Download individual folders:")
print("   Files panel (left) → /content/cross_validation_results/")
print("   Right-click any folder → Download")
print("\n3️⃣ What's included:")
print("   • fold1/, fold2/, ... fold6/ (best models + configs)")
print("   • Each fold: ~30-50 MB")
print("="*70)

In [None]:
# Download the ZIP file
from google.colab import files

print("Downloading cross_validation_results.zip...")
print("This may take a few minutes depending on the size...")
files.download(output_zip)
print("✅ Download complete!")

## 🖥️ Continue Locally

After downloading, extract and evaluate on your local machine:

```bash
# Extract the ZIP
unzip cross_validation_results.zip -d ./cv_models

# Batch evaluate all folds
cd synthsup-speechMRI-recon
python utils/evaluate_all_folds.py \
    --models-dir ./cv_models \
    --input-dir ../Synth_LR_nii \
    --target-dir ../HR_nii \
    --output-dir ./evaluation_results

# Test best model on held-out subject (replace fold2 with the fold you selected)
python inference_unet.py \
    --checkpoint ./cv_models/fold2/fold2_best.pth \
    --input-dir ../Synth_LR_nii \
    --target-dir ../HR_nii \
    --output-dir ./inference_results/heldout \
    --file-pattern "*Subject0021*.nii" \
    --compute-metrics --visualize
```

See `docs/COLAB_TO_LOCAL_WORKFLOW.md` for detailed instructions.