# U-Net Training for MRI Reconstruction on Google Colab

This notebook provides a template for training the U-Net model on Google Colab with GPU acceleration.

**Before starting:**
1. Upload your data to Google Drive (Synth_LR_nii/ and HR_nii/ folders)
2. Enable GPU: Runtime → Change runtime type → GPU (T4, V100, or A100)
3. Run cells sequentially

## 1. Check GPU Availability

In [None]:
import torch
import sys

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ WARNING: GPU not available! Go to Runtime → Change runtime type → GPU")

## 2. Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Verify mount
!ls "/content/drive/MyDrive"

## 3. Clone Repository

In [None]:
# Clone repository
!cd /content && git clone https://github.com/marioknicola/synthsup-speechMRI-recon.git

# Change to repository directory
%cd /content/synthsup-speechMRI-recon

# Verify structure
!ls -la

## 4. Install Dependencies

In [None]:
# Install requirements
!pip install -q -r requirements.txt

# Verify key packages
import nibabel as nib
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
print("✅ Key packages imported successfully!")
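Recording the exact package versions alongside your results makes a run easier to reproduce later. Below is a minimal sketch using the standard-library `importlib.metadata`. The helper `installed_versions` is hypothetical, and the distribution names (`torch`, `nibabel`, `scikit-image`) are assumptions based on the imports above; adjust them to match `requirements.txt`.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names are assumptions; scikit-image is imported as `skimage`.
for pkg, ver in installed_versions(["torch", "nibabel", "scikit-image"]).items():
    print(f"{pkg}: {ver or 'NOT INSTALLED'}")
```

Note that `version()` expects the name used by `pip install`, which is not always the same as the import name.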

## 5. Setup Data Paths

**IMPORTANT:** Update these paths to match your Google Drive structure

In [None]:
# ⚠️ UPDATE THESE PATHS TO MATCH YOUR DRIVE STRUCTURE
INPUT_DIR = "/content/drive/MyDrive/MRI_Data/Synth_LR_nii"
TARGET_DIR = "/content/drive/MyDrive/MRI_Data/HR_nii"
OUTPUT_DIR = "/content/drive/MyDrive/MRI_Data/outputs"

# Verify data exists
import os
print(f"Input dir exists: {os.path.exists(INPUT_DIR)}")
print(f"Target dir exists: {os.path.exists(TARGET_DIR)}")

if os.path.exists(INPUT_DIR):
    input_files = os.listdir(INPUT_DIR)
    print(f"Found {len(input_files)} input files")
    print(f"Sample files: {input_files[:3]}")
else:
    print("⚠️ Input directory not found! Update INPUT_DIR path.")

if os.path.exists(TARGET_DIR):
    target_files = os.listdir(TARGET_DIR)
    print(f"Found {len(target_files)} target files")
else:
    print("⚠️ Target directory not found! Update TARGET_DIR path.")

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"Output directory ready: {OUTPUT_DIR}")
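Matching file counts do not guarantee that every input has a target. The sketch below pairs files by identical filename and lists any mismatches. It assumes that each low-resolution file in `Synth_LR_nii/` has a high-resolution counterpart with the same name in `HR_nii/`. This convention is a guess; `train_unet.py` may pair files differently. `pair_files` is a hypothetical helper.

```python
import os


def pair_files(input_dir, target_dir, exts=(".nii", ".nii.gz")):
    """Match input and target NIfTI files by identical filename (assumed convention).

    Returns (paired, inputs_without_target, targets_without_input), each sorted.
    """
    def nifti_names(directory):
        # Keep only NIfTI files; ignore notes, hidden files, etc.
        return {f for f in os.listdir(directory) if f.endswith(exts)}

    inputs = nifti_names(input_dir)
    targets = nifti_names(target_dir)
    paired = sorted(inputs & targets)
    return paired, sorted(inputs - targets), sorted(targets - inputs)
```

In the notebook, call `pair_files(INPUT_DIR, TARGET_DIR)` after the cell above. Both mismatch lists should be empty before you start training.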

## 6. Start Training

### Quick Test (10 epochs)

In [None]:
# Quick test training
!python3 train_unet.py \
    --input-dir "{INPUT_DIR}" \
    --target-dir "{TARGET_DIR}" \
    --output-dir "{OUTPUT_DIR}" \
    --epochs 10 \
    --batch-size 4 \
    --base-filters 32 \
    --lr 1e-4
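The quick test also gives you a timing baseline. If you add IPython's `%%time` magic as the first line of the quick-test cell, it reports the wall time. Dividing that time by the epoch count and scaling it up gives a rough estimate for the full run on the same GPU. This sketch assumes that epoch time stays roughly constant and that any one-off start-up cost (data loading, CUDA initialisation) can be subtracted. The example numbers are hypothetical, not measurements.

```python
def estimate_full_run_hours(quick_minutes, quick_epochs=10, full_epochs=100,
                            startup_minutes=0.0):
    """Linearly extrapolate wall-clock hours for a full run from a short test run."""
    per_epoch = (quick_minutes - startup_minutes) / quick_epochs
    return (startup_minutes + per_epoch * full_epochs) / 60.0


# Hypothetical example: the 10-epoch test took 30 min, ~2 min of which was start-up.
print(f"{estimate_full_run_hours(30, startup_minutes=2):.1f} h")
```

Treat the result only as a rough estimate. Validation passes, Drive I/O, and GPU assignment can all shift the real figure.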

### Full Training (100 epochs)

**Note:** This takes roughly 4-6 hours on a T4 GPU.

In [None]:
# Full training
!python3 train_unet.py \
    --input-dir "{INPUT_DIR}" \
    --target-dir "{TARGET_DIR}" \
    --output-dir "{OUTPUT_DIR}" \
    --epochs 100 \
    --batch-size 4 \
    --base-filters 32 \
    --lr 1e-4

## 7. Monitor Training with TensorBoard

In [None]:
# Load TensorBoard extension
%load_ext tensorboard

# Launch TensorBoard
%tensorboard --logdir "{OUTPUT_DIR}/logs"
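A `!python3 train_unet.py` cell blocks the notebook until training finishes, so this TensorBoard cell (and `nvidia-smi`) cannot run while that cell is busy. One workaround is to start training as a background process that writes its output to a log file. The sketch below does this with `subprocess.Popen`. `launch_in_background` and the `train.log` filename are hypothetical. The flags are copied from the full-training cell.

```python
import subprocess


def launch_in_background(cmd, log_path):
    """Start cmd as a child process with stdout and stderr redirected to log_path.

    Returns the Popen handle immediately; the calling cell does not wait.
    """
    with open(log_path, "w") as log_file:
        # The child inherits its own copy of the file descriptor, so the
        # parent can close its handle as soon as the process has started.
        proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)
    return proc


# Example usage in the notebook ("-u" disables Python output buffering so the
# log updates promptly):
# proc = launch_in_background(
#     ["python3", "-u", "train_unet.py",
#      "--input-dir", INPUT_DIR, "--target-dir", TARGET_DIR,
#      "--output-dir", OUTPUT_DIR, "--epochs", "100", "--batch-size", "4",
#      "--base-filters", "32", "--lr", "1e-4"],
#     f"{OUTPUT_DIR}/train.log")
# proc.poll() returns None while training is still running.
```

You can then follow progress with `!tail -n 20 "{OUTPUT_DIR}/train.log"` while TensorBoard stays usable.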

## 8. Monitor GPU Usage (Optional)

In [None]:
# Check GPU memory usage
!nvidia-smi
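For a compact readout, for example to check memory headroom before raising `--batch-size`, `nvidia-smi` can print selected fields as CSV via `--query-gpu` and `--format=csv,noheader,nounits`. The sketch below parses that output. The helper names are hypothetical. It returns an empty list when `nvidia-smi` is unavailable or fails.

```python
import shutil
import subprocess

QUERY = "memory.used,memory.total,utilization.gpu"


def parse_gpu_query(csv_text):
    """Parse 'memory.used, memory.total, utilization.gpu' lines (MiB, MiB, %)."""
    gpus = []
    for line in csv_text.strip().splitlines():
        used, total, util = (int(x) for x in line.split(","))
        gpus.append({"used_mib": used, "total_mib": total, "util_pct": util})
    return gpus


def query_gpus():
    """Run nvidia-smi once; return [] if it is missing or exits with an error."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", f"--query-gpu={QUERY}", "--format=csv,noheader,nounits"],
        capture_output=True, text=True,
    )
    return parse_gpu_query(result.stdout) if result.returncode == 0 else []


for i, gpu in enumerate(query_gpus()):
    print(f"GPU {i}: {gpu['used_mib']}/{gpu['total_mib']} MiB, {gpu['util_pct']}% util")
```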

## 9. Check Training Results

In [None]:
# List checkpoints
import os
checkpoint_dir = os.path.join(OUTPUT_DIR, "checkpoints")
if os.path.exists(checkpoint_dir):
    checkpoints = sorted(os.listdir(checkpoint_dir))
    print(f"Found {len(checkpoints)} checkpoints:")
    for ckpt in checkpoints:
        size_mb = os.path.getsize(os.path.join(checkpoint_dir, ckpt)) / 1e6
        print(f"  - {ckpt} ({size_mb:.1f} MB)")
else:
    print("No checkpoints found yet.")

# Check test indices
test_indices_path = os.path.join(OUTPUT_DIR, "test_indices.txt")
if os.path.exists(test_indices_path):
    with open(test_indices_path, 'r') as f:
        test_indices = f.read().strip().split(',')
    print(f"\nTest set: {len(test_indices)} samples")
    print(f"Indices: {test_indices[:10]}...")
else:
    print("\nTest indices not saved yet.")
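`sorted(os.listdir(...))` sorts filenames as strings, so `checkpoint_epoch_100.pth` is listed before `checkpoint_epoch_20.pth`. To order checkpoints by epoch instead, parse the number out of the name. This sketch assumes the `checkpoint_epoch_<N>.pth` naming that the resume cell's glob pattern expects. `sort_checkpoints_by_epoch` is a hypothetical helper.

```python
import re

# Matches names such as "checkpoint_epoch_20.pth" (zero-padded or not).
_EPOCH_RE = re.compile(r"checkpoint_epoch_(\d+)\.pth$")


def sort_checkpoints_by_epoch(filenames):
    """Return (epoch, filename) pairs for periodic checkpoints, ascending by epoch.

    Files that do not follow the naming pattern (e.g. best_model.pth) are skipped.
    """
    found = []
    for name in filenames:
        match = _EPOCH_RE.search(name)
        if match:
            found.append((int(match.group(1)), name))
    return sorted(found)
```

For example, `sort_checkpoints_by_epoch(os.listdir(checkpoint_dir))[-1]` gives the most recent periodic checkpoint, provided that at least one exists.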

## 10. Download Trained Model (Optional)

In [None]:
# Download best model to local machine
from google.colab import files

best_model_path = os.path.join(OUTPUT_DIR, "checkpoints", "best_model.pth")
if os.path.exists(best_model_path):
    files.download(best_model_path)
    print("✅ Model downloaded!")
else:
    print("⚠️ Best model not found. Training may not be complete.")
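To download more than the best model (for example, the TensorBoard logs), you can first bundle a directory into one zip file with `shutil.make_archive`. This is a minimal sketch. `archive_outputs` is a hypothetical helper, and the `/content/unet_logs` target in the usage note is only a suggestion. Write the archive outside the directory being zipped, so that it does not try to include itself.

```python
import shutil


def archive_outputs(source_dir, archive_base):
    """Zip the contents of source_dir into archive_base + '.zip'; return the zip path."""
    # root_dir makes paths inside the archive relative to source_dir.
    return shutil.make_archive(archive_base, "zip", root_dir=source_dir)
```

In the notebook, `files.download(archive_outputs(os.path.join(OUTPUT_DIR, "logs"), "/content/unet_logs"))` downloads the logs as a single file. Zipping the whole `OUTPUT_DIR` also works, but checkpoints can make the archive large.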

## 11. Resume Training (If Disconnected)

In [None]:
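# Find the checkpoint with the highest epoch number
import glob
import os
import re

def epoch_of(path):
    # Extract N from ".../checkpoint_epoch_N.pth"; file timestamps on Drive
    # are a less reliable way to find the latest checkpoint
    return int(re.search(r"checkpoint_epoch_(\d+)\.pth$", path).group(1))

checkpoints = glob.glob(os.path.join(OUTPUT_DIR, "checkpoints", "checkpoint_epoch_*.pth"))
if checkpoints:
    latest_checkpoint = max(checkpoints, key=epoch_of)
    print(f"Resuming from: {latest_checkpoint}")

    # Resume training with the same model/optimizer flags as the original run
    !python3 train_unet.py \
        --input-dir "{INPUT_DIR}" \
        --target-dir "{TARGET_DIR}" \
        --output-dir "{OUTPUT_DIR}" \
        --resume "{latest_checkpoint}" \
        --epochs 100 \
        --batch-size 4 \
        --base-filters 32 \
        --lr 1e-4
else:
    print("No checkpoints found to resume from.")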

## 12. Tips for Long Training Sessions

### Prevent Colab Disconnection

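Some users paste the following JavaScript into the browser console (F12). It clicks Colab's connect button every 60 seconds. It depends on Colab's page markup, so it may stop working when the UI changes:

```javascript
function ClickConnect() {
  console.log("Keeping connection alive");
  // querySelector returns null if the element is not on the page
  const button = document.querySelector("colab-connect-button");
  if (button) button.click();
}
setInterval(ClickConnect, 60000);
```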

### Save Progress Regularly

The training script automatically saves:
- Best model (lowest validation loss)
- Checkpoints every 10 epochs
- TensorBoard logs continuously

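As long as `OUTPUT_DIR` points to Google Drive, these files survive a runtime disconnect. At most, you lose the epochs completed since the last periodic checkpoint.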
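The cost of a disconnect is simple to work out. The sketch below assumes that periodic checkpoints are written after epochs 10, 20, 30, and so on (epochs counted from 1). The actual schedule in `train_unet.py` may differ. Both function names are hypothetical.

```python
def last_checkpoint_epoch(completed_epochs, checkpoint_every=10):
    """Epoch of the most recent periodic checkpoint (0 means none yet)."""
    return (completed_epochs // checkpoint_every) * checkpoint_every


def epochs_to_redo(completed_epochs, checkpoint_every=10):
    """Completed epochs not covered by a checkpoint, i.e. work lost on disconnect."""
    return completed_epochs - last_checkpoint_epoch(completed_epochs, checkpoint_every)


# Example: a disconnect after 37 completed epochs resumes from epoch 30.
print(last_checkpoint_epoch(37), epochs_to_redo(37))
```

In the worst case, you redo `checkpoint_every - 1` epochs. If disconnects are frequent, checkpointing more often lowers this cost but uses more Drive space.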

### Colab Pro Benefits

- 24h runtime (vs 12h free)
- Faster GPUs (V100, A100)
- More RAM (25GB vs 12GB)
- Background execution

Consider upgrading if you train frequently.