# MRI Denoising Training - Google Colab

This notebook sets up the environment and runs the training for the MRI Denoising project.

### 1. Mount Google Drive and Set Paths
Decide whether to work on Google Drive or locally.

In [None]:
from google.colab import drive
import os
import sys

# Mount Drive
drive.mount('/content/drive')

# Decide between Google Drive or Local Local path
USE_DRIVE = True # @param {type:"boolean"}
DRIVE_PATH = "/content/drive/MyDrive/MRI_Training"
LOCAL_PATH = "/content"

if USE_DRIVE:
    WORKDIR = DRIVE_PATH
    os.makedirs(WORKDIR, exist_ok=True)
else:
    WORKDIR = LOCAL_PATH

%cd {WORKDIR}
print(f"Current working directory: {os.getcwd()}")

### 2. Basic Installation (Python & GitHub)
Clone the repository and install dependencies.

In [None]:
REPO_URL = "https://github.com/diegoaeifer/MRI_Denoise.git"
REPO_NAME = "FMImaging_MRI_Denoise"

if not os.path.exists(REPO_NAME):
    print(f"Cloning {REPO_NAME}...")
    !git clone {REPO_URL} {REPO_NAME}
else:
    print(f"{REPO_NAME} exists. Pulling updates...")
    !cd {REPO_NAME} && git pull

# Install requirements
!pip install -r {REPO_NAME}/requirements.txt

### 3. Augmentation Settings
Customize the data augmentation parameters below.

In [None]:
import yaml

# --- Noise Parameters ---
sigma_min = 0.05  # @param {type:"slider", min:0.0, max:0.3, step:0.01}
sigma_max = 0.10  # @param {type:"slider", min:0.0, max:0.3, step:0.01}
noise_grid_size = 4  # @param {type:"integer"}

# --- Probability Parameters ---
flip_prob = 0.5  # @param {type:"slider", min:0.0, max:1.0, step:0.1}
rotate_prob = 0.5  # @param {type:"slider", min:0.0, max:1.0, step:0.1}
affine_prob = 0.2  # @param {type:"slider", min:0.0, max:1.0, step:0.1}
gamma_prob = 0.2  # @param {type:"slider", min:0.0, max:1.0, step:0.1}
anisotropy_prob = 0.1  # @param {type:"slider", min:0.0, max:1.0, step:0.1}
bias_field_prob = 0.1  # @param {type:"slider", min:0.0, max:1.0, step:0.1}
ghosting_prob = 0.02  # @param {type:"slider", min:0.0, max:1.0, step:0.01}
motion_prob = 0.05  # @param {type:"slider", min:0.0, max:1.0, step:0.01}
blur_prob = 0.0  # @param {type:"slider", min:0.0, max:1.0, step:0.1}
spike_prob = 0.0  # @param {type:"slider", min:0.0, max:1.0, step:0.01}

# --- Strength Parameters ---
gamma_min = 0.8 # @param {type:"slider", min:0.5, max:1.0, step:0.05}
gamma_max = 1.2 # @param {type:"slider", min:1.0, max:1.5, step:0.05}
bias_field_coeffs = 0.3 # @param {type:"slider", min:0.0, max:1.0, step:0.1}
anisotropy_downsampling = 1.5 # @param {type:"slider", min:1.0, max:3.0, step:0.1}

# Gather every form value under data -> augmentation, the section that the
# training cell merges into the base configuration.
aug_override = {
    'data': {
        'augmentation': dict(
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            noise_grid_size=noise_grid_size,
            flip_prob=flip_prob,
            rotate_prob=rotate_prob,
            affine_prob=affine_prob,
            gamma_prob=gamma_prob,
            anisotropy_prob=anisotropy_prob,
            bias_field_prob=bias_field_prob,
            ghosting_prob=ghosting_prob,
            motion_prob=motion_prob,
            blur_prob=blur_prob,
            spike_prob=spike_prob,
            # The two gamma sliders are stored together as a [min, max] pair
            gamma_range=[gamma_min, gamma_max],
            bias_field_coeffs=bias_field_coeffs,
            anisotropy_downsampling=anisotropy_downsampling,
        ),
    },
}

# Persist the overrides in the current working directory
with open('aug_override.yaml', 'w') as override_file:
    yaml.dump(aug_override, override_file)

print("Augmentation overrides saved to aug_override.yaml")

### 4. Training Execution
Set your training options here before running.
You can specify custom data and output paths.

In [None]:
import os
# Suppress TensorFlow/TensorBoard warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import yaml
# === TRANING OPTIONS ===
MODEL = "nafnet"  # @param ["nafnet", "drunet", "scunet", "unet"]
CONFIG = "FMImaging_MRI_Denoise/configs/config_nafnet_production.yaml" # @param {type:"string"}
OPTIMIZER = "Adam" # @param ["Adam", "AdamW"]
SCHEDULER = "CosineAnnealing" # @param ["CosineAnnealing", "CosineAnnealing2", "ReduceLROnPlateau", "StepLR"]
LIMIT_DATA = None # @param {type:"raw"}
USE_AUG_OVERRIDE = True # @param {type:"boolean"}

# === PATH OPTIONS ===
# Leave empty to use default config paths
DATA_DIR = "" # @param {type:"string"}
OUTPUT_DIR = "" # @param {type:"string"}
USE_TENSORBOARD = True # @param {type:"boolean"}

# Construct the command
cmd = f"python {REPO_NAME}/src/train.py --model {MODEL} --config {CONFIG}"

if USE_AUG_OVERRIDE:
    print("Merging configuration files...")
    with open(CONFIG, 'r') as f:
        base_conf = yaml.safe_load(f)
    with open('aug_override.yaml', 'r') as f:
        over_conf = yaml.safe_load(f)
    
    # Merge data -> augmentation specifically
    if 'data' not in base_conf:
        base_conf['data'] = {}
    if 'augmentation' not in base_conf['data']:
        base_conf['data']['augmentation'] = {}
    
    base_conf['data']['augmentation'].update(over_conf['data']['augmentation'])
    
    # Merge Optimizer/Scheduler overrides from Notebook
    if 'training' not in base_conf:
        base_conf['training'] = {}
    base_conf['training']['optimizer'] = OPTIMIZER
    base_conf['training']['scheduler'] = SCHEDULER
            
    with open('final_config.yaml', 'w') as f:
        yaml.dump(base_conf, f)
    
    cmd = f"python {REPO_NAME}/src/train.py --model {MODEL} --config final_config.yaml"

if LIMIT_DATA is not None:
    cmd += f" --limit {LIMIT_DATA}"

if DATA_DIR:
    cmd += f" --data_dir \"{DATA_DIR}\""

if OUTPUT_DIR:
    cmd += f" --output_dir \"{OUTPUT_DIR}\""

# Start Tensorboard if requested
if USE_TENSORBOARD:
    %load_ext tensorboard
    log_path = os.path.join(OUTPUT_DIR if OUTPUT_DIR else f"{REPO_NAME}/experiments", "logs")
    os.makedirs(log_path, exist_ok=True)
    %tensorboard --logdir "{log_path}"

print(f"Executing command: {cmd}")
!{cmd}