# Iteration Control for Deconvolution

[Colab Link](https://colab.research.google.com/github/casangi/astroviper/blob/main/docs/core_tutorials/imaging/iteration_control_demo.ipynb)

This notebook demonstrates the **iteration control module** for managing major/minor cycle workflows in deconvolution algorithms. The module provides adaptive threshold calculation, convergence checking, and iteration count management compatible with CASA's behavior.

---

## Assumptions and Background

**Deconvolution Iteration Control:**

- **Major Cycle:** One complete iteration involving gridding visibilities, computing residual images (FFT), and running minor cycle deconvolution. Major cycles are expensive (I/O intensive).
  
- **Minor Cycle:** Multiple fast CLEAN iterations operating on the image plane only, typically 10-1000 iterations per major cycle.

- **Adaptive Cyclethreshold:** Prevents cleaning too deeply into PSF sidelobes during a single minor cycle:
  ```
  psf_fraction = cyclefactor × max_psf_sidelobe
  psf_fraction = clamp(psf_fraction, minpsffraction, maxpsffraction)
  cyclethreshold = max(psf_fraction × peak_residual, threshold)
  ```

- **Convergence Criteria:** Deconvolution stops when:
  1. Maximum iterations exhausted (`niter <= 0`)
  2. Peak residual below threshold (`peak_residual <= threshold`)
  3. Maximum major cycles reached (`nmajor == 0`)
  4. No valid pixels remaining (`masksum == 0`)

- **ReturnDict:** A data structure holding deconvolution results (peak residual, iterations done, mask statistics, PSF info) indexed by time, polarization, and channel coordinates.
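
The adaptive cyclethreshold formula above can be written out as a small standalone function. This is a minimal sketch, not the module's implementation: `cycle_threshold` is a hypothetical name, and the defaults are taken from the `IterationController` signature shown in the API section.

```python
def cycle_threshold(peak_residual, max_psf_sidelobe, threshold=0.0,
                    cyclefactor=1.0, minpsffraction=0.05, maxpsffraction=0.8):
    """Return the adaptive threshold for the next minor cycle (sketch)."""
    psf_fraction = cyclefactor * max_psf_sidelobe
    # Clamp the fraction to [minpsffraction, maxpsffraction]
    psf_fraction = min(max(psf_fraction, minpsffraction), maxpsffraction)
    # Never clean below the global stopping threshold
    return max(psf_fraction * peak_residual, threshold)


# Bright residual: the PSF-based term dominates
print(f"{cycle_threshold(1.0, 0.2, threshold=0.01, cyclefactor=1.5):.4f}")   # 0.3000
# Faint residual: the global threshold takes over
print(f"{cycle_threshold(0.02, 0.2, threshold=0.01, cyclefactor=1.5):.4f}")  # 0.0100
```

With `cyclefactor=1.5` and `max_psf_sidelobe=0.2`, each minor cycle stops at 30% of the starting peak until the global threshold becomes the larger term.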

---
## Pseudo Code

```
# Initialize controller
controller = IterationController(niter, nmajor, threshold, ...)

# Major cycle loop
while controller.stopcode.major == 0:
    # Calculate adaptive controls for next minor cycle
    cycleniter, cyclethreshold = controller.calculate_cycle_controls(return_dict)
    
    # Run deconvolution (minor cycle)
    return_dict = deconvolve(dirty, psf, niter=cycleniter, threshold=cyclethreshold)
    
    # Update counts
    controller.update_counts(return_dict)
    
    # Check convergence
    stopcode, description = controller.check_convergence(return_dict)
```
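
The `check_convergence` step in the loop above applies the four stopping criteria listed under Assumptions and Background. A minimal sketch of that decision, using a hypothetical `check_stop` helper rather than the module's own code, might look as follows. The order of the checks is an assumption (it simply follows the list), and the description strings are copied from the outputs printed later in this notebook.

```python
def check_stop(niter, nmajor, threshold, peak_residual, masksum):
    """Return a description of the first stopping criterion met (sketch).

    Hypothetical helper: the order of the checks is an assumption.
    """
    if niter <= 0:
        return "Reached the iteration limit"
    if peak_residual <= threshold:
        return "Reached global stopping threshold (within mask)"
    if nmajor == 0:  # -1 means unlimited, so only an exact 0 stops
        return "Reached the major cycle limit (nmajor)"
    if masksum == 0:
        return "Zero mask"
    return "Continue iterations"


print(check_stop(niter=195, nmajor=9, threshold=0.8, peak_residual=0.95, masksum=100))
print(check_stop(niter=0, nmajor=5, threshold=0.001, peak_residual=0.005, masksum=100))
```

The first call prints `Continue iterations` and the second prints `Reached the iteration limit`.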

---
## API

In [1]:
from astroviper.core.imaging.imaging_utils.iteration_control import IterationController
IterationController?

Init signature:
IterationController(
    niter: int = 1000,
    nmajor: int = -1,
    threshold: float = 0.0,
    gain: float = 0.1,
    cyclefactor: float = 1.0,
    minpsffraction: float = 0.05,
    maxpsffraction: float = 0.8,
    cycleniter: int = -1,
    nsigma: float = 0.0,
)
Docstring:
Manages iteration control logic for deconvolution algorithms.

The controller extracts needed statistics (peak residual, masksum, iterations
done, etc.) from ReturnDict internally, so callers don't need to manually
pass individual values.

Attributes:
-----------
niter : int
    Maximum number of minor cycle iterations remaining
nmajor : int
    Maximum number of major cycles remaining (-1 = unlimited)
threshold : float
    Global stopping threshold (in Jy or image units)
gain : float
    CLEAN loop gain (typically 0.1)
cyclefactor : float
    Multiplier for PSF sidelobe to set cyclet

---
## Notes

- This module closely follows CASA's iteration control logic (from `_gclean.py` and `imager_return_dict.py`)


## Example 1: Basic Iteration Control with Convergence

This example demonstrates the basic iteration control workflow using a simple mock deconvolution that stops once the peak residual reaches the threshold.

### Setup Mock Data and Helper Functions

In [2]:
import numpy as np
from collections import OrderedDict

# Import iteration control module
# Note: Adjust path as needed for your environment
import sys
sys.path.insert(0, '.')
from astroviper.core.imaging.imaging_utils.iteration_control import (
    IterationController,
    ReturnDict,
    merge_return_dicts,
    get_peak_residual_from_returndict,
    get_iterations_done_from_returndict,
)

def mock_deconvolve(initial_peak, niter, threshold, gain=0.1):
    """
    Mock deconvolution that simulates cleaning behavior.
    Peak residual decreases by gain factor each iteration.
    """
    peak_residual = initial_peak
    iterations_done = 0
    
    for i in range(niter):
        if peak_residual <= threshold:
            break
        peak_residual *= (1 - gain)  # Residual decreases
        iterations_done += 1
    
    return peak_residual, iterations_done

def create_return_dict(peak_residual, iterations_done, masksum=100, max_psf_sidelobe=0.2,
                      time=0, pol=0, chan=0):
    """
    Create a ReturnDict with deconvolution results.
    """
    rd = ReturnDict()
    rd.add({
        'peakres': peak_residual,
        'peakres_nomask': peak_residual,
        'masksum': masksum,
        'iter_done': iterations_done,
        'max_psf_sidelobe': max_psf_sidelobe,
        'niter': iterations_done,
        'threshold': 0.0,
        'loop_gain': 0.1,
    }, time=time, pol=pol, chan=chan)
    return rd
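
Because `mock_deconvolve` multiplies the peak by `(1 - gain)` on every iteration, the number of iterations needed to reach a given threshold has a closed form: the smallest `n` with `peak * (1 - gain)**n <= threshold`. The sketch below computes it with a hypothetical helper, `iterations_to_threshold`, which is not part of the module. It assumes a positive threshold, and floating-point rounding could make it differ by one from the loop when the ratio falls exactly on a boundary.

```python
import math


def iterations_to_threshold(peak, threshold, gain=0.1):
    """Smallest n with peak * (1 - gain)**n <= threshold (sketch)."""
    if threshold <= 0:
        raise ValueError("threshold must be positive for the closed form")
    if peak <= threshold:
        return 0  # already at or below the threshold
    return math.ceil(math.log(threshold / peak) / math.log(1.0 - gain))


n = iterations_to_threshold(1.0, 0.3, gain=0.1)
print(n, f"{1.0 * 0.9 ** n:.6f}")  # 12 0.282430
```

With `gain=0.1` and a cycle threshold at 30% of the starting peak, every minor cycle therefore takes 12 iterations unless the iteration budget or the global threshold intervenes.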

### Run Basic Iteration Control Loop

In [3]:
# Initialize controller
controller = IterationController(
    niter=1000,
    nmajor=10,
    threshold=0.01,  # Stop when residual drops below 0.01 Jy
    gain=0.1,
    cyclefactor=1.5,
    cycleniter=100,  # Max 100 iterations per minor cycle
)

print("Initial Controller State:")
print(f"  niter: {controller.niter}")
print(f"  nmajor: {controller.nmajor}")
print(f"  threshold: {controller.threshold} Jy")
print(f"  cycleniter: {controller.cycleniter}")
print(f"  cyclefactor: {controller.cyclefactor}")
print()

# Simulate major cycle loop
current_peak = 1.0  # Starting peak residual: 1 Jy
major_cycle = 0

print("Starting major cycle loop...\n")

while controller.stopcode.major == 0:
    major_cycle += 1
    print(f"=== Major Cycle {major_cycle} ===")
    
    # Create mock return_dict with current state
    current_rd = create_return_dict(
        peak_residual=current_peak,
        iterations_done=0,  # Will be updated after deconvolution
        masksum=100,
        max_psf_sidelobe=0.2
    )
    
    # Calculate adaptive cycle controls
    cycleniter, cyclethresh = controller.calculate_cycle_controls(current_rd)
    print(f"  Cycle controls: niter={cycleniter}, threshold={cyclethresh:.4f} Jy")
    
    # Run mock deconvolution
    new_peak, iters_done = mock_deconvolve(
        initial_peak=current_peak,
        niter=cycleniter,
        threshold=cyclethresh,
        gain=controller.gain
    )
    
    print(f"  Deconvolution: {iters_done} iterations, peak: {current_peak:.6f} -> {new_peak:.6f} Jy")
    
    # Update current peak
    current_peak = new_peak
    
    # Create return_dict with results
    result_rd = create_return_dict(
        peak_residual=new_peak,
        iterations_done=iters_done,
        masksum=100,
        max_psf_sidelobe=0.2
    )
    
    # Update iteration counts
    controller.update_counts(result_rd)
    print(f"  Updated counts: niter remaining={controller.niter}, nmajor remaining={controller.nmajor}")
    print(f"                  total iterations done={controller.total_iter_done}")
    
    # Check convergence
    stopcode, stopdesc = controller.check_convergence(result_rd)
    
    # stopcode is a StopCode(major, minor) object, so test its major field
    if stopcode.major != 0:
        print(f"\n  *** CONVERGED: {stopdesc} ***")
        print(f"  Final stopcode: {stopcode}")
    print()

print("\n=== Final State ===")
print(f"Total major cycles: {controller.major_done}")
print(f"Total iterations: {controller.total_iter_done}")
print(f"Final peak residual: {current_peak:.6f} Jy")
print(f"Stop code: {controller.stopcode} - {controller.stopdescription}")

Initial Controller State:
  niter: 1000
  nmajor: 10
  threshold: 0.01 Jy
  cycleniter: 100
  cyclefactor: 1.5

Starting major cycle loop...

=== Major Cycle 1 ===
  Cycle controls: niter=100, threshold=0.3000 Jy
  Deconvolution: 12 iterations, peak: 1.000000 -> 0.282430 Jy
  Updated counts: niter remaining=988, nmajor remaining=9
                  total iterations done=12

=== Major Cycle 2 ===
  Cycle controls: niter=100, threshold=0.0847 Jy
  Deconvolution: 12 iterations, peak: 0.282430 -> 0.079766 Jy
  Updated counts: niter remaining=976, nmajor remaining=8
                  total iterations done=24

=== Major Cycle 3 ===
  Cycle controls: niter=100, threshold=0.0239 Jy
  Deconvolution: 12 iterations, peak: 0.079766 -> 0.022528 Jy
  Updated counts: niter remaining=964, nmajor remaining=7
                  total

## Example 2: Different Convergence Scenarios

Demonstrate different stopping conditions by varying initial parameters.

### Scenario A: Iteration Limit (stopcode 1)

In [4]:
# Controller with low niter - will exhaust iterations before reaching threshold
controller_A = IterationController(
    niter=50,  # Only 50 iterations total
    nmajor=10,
    threshold=0.001,  # Very low threshold (hard to reach)
    gain=0.1,
    cyclefactor=1.5,
    cycleniter=20,
)

current_peak = 1.0
print("Scenario A: Exhaust iteration limit\n")

while controller_A.stopcode.major == 0:
    current_rd = create_return_dict(current_peak, 0, 100, 0.2)
    cycleniter, cyclethresh = controller_A.calculate_cycle_controls(current_rd)
    
    new_peak, iters_done = mock_deconvolve(current_peak, cycleniter, cyclethresh, 0.1)
    current_peak = new_peak
    
    result_rd = create_return_dict(new_peak, iters_done, 100, 0.2)
    controller_A.update_counts(result_rd)
    controller_A.check_convergence(result_rd)

print(f"Result: {controller_A.stopdescription}")
print(f"Stopcode: {controller_A.stopcode}")
print(f"Total iterations: {controller_A.total_iter_done}")
print(f"Final peak: {current_peak:.6f} Jy\n")

Scenario A: Exhaust iteration limit

Result: Reached the iteration limit
Stopcode: StopCode(major=1, minor=0)
Total iterations: 50
Final peak: 0.005154 Jy
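
The totals above can be cross-checked without the library. The sketch below is a plain-Python stand-in, not `IterationController` itself: `replay` and its arguments are hypothetical names. It assumes the per-cycle budget is `min(cycleniter, niter remaining)`, that a non-positive `cycleniter` means no per-cycle cap, and that the cycle threshold follows the formula in the background section.

```python
def replay(niter, nmajor, threshold, cycleniter, peak=1.0, gain=0.1,
           cyclefactor=1.5, max_psf_sidelobe=0.2,
           minpsffraction=0.05, maxpsffraction=0.8):
    """Replay the mock major/minor loop; return (total_iters, majors, peak)."""
    total = majors = 0
    # Stop on exhausted niter, exhausted nmajor, or peak at/below threshold
    while niter > 0 and nmajor != 0 and peak > threshold:
        frac = cyclefactor * max_psf_sidelobe
        frac = min(max(frac, minpsffraction), maxpsffraction)
        cyclethresh = max(frac * peak, threshold)
        # Assumed cap on the minor-cycle budget
        budget = niter if cycleniter <= 0 else min(cycleniter, niter)
        done = 0
        while done < budget and peak > cyclethresh:
            peak *= 1.0 - gain  # same decay as mock_deconvolve
            done += 1
        niter -= done
        nmajor -= 1
        total += done
        majors += 1
    return total, majors, peak


total, majors, peak = replay(niter=50, nmajor=10, threshold=0.001, cycleniter=20)
print(total, majors, f"{peak:.6f}")  # 50 5 0.005154
```

Under these assumptions the cycles use 12, 12, 12, 12, and 2 iterations, which reproduces the 50 total iterations and the final peak of 0.005154 Jy. Calling `replay(niter=1000, nmajor=3, threshold=0.001, cycleniter=20)` gives 36 iterations and a peak of 0.022528 Jy, in line with Scenario B below.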



### Scenario B: Major Cycle Limit (stopcode 9)

In [5]:
# Controller with low nmajor - will exhaust major cycles
controller_B = IterationController(
    niter=1000,
    nmajor=3,  # Only 3 major cycles
    threshold=0.001,
    gain=0.1,
    cyclefactor=1.5,
    cycleniter=20,
)

current_peak = 1.0
print("Scenario B: Exhaust major cycle limit\n")

while controller_B.stopcode.major == 0:
    current_rd = create_return_dict(current_peak, 0, 100, 0.2)
    cycleniter, cyclethresh = controller_B.calculate_cycle_controls(current_rd)
    
    new_peak, iters_done = mock_deconvolve(current_peak, cycleniter, cyclethresh, 0.1)
    current_peak = new_peak
    
    result_rd = create_return_dict(new_peak, iters_done, 100, 0.2)
    controller_B.update_counts(result_rd)
    controller_B.check_convergence(result_rd)

print(f"Result: {controller_B.stopdescription}")
print(f"Stopcode: {controller_B.stopcode}")
print(f"Major cycles done: {controller_B.major_done}")
print(f"Total iterations: {controller_B.total_iter_done}")
print(f"Final peak: {current_peak:.6f} Jy\n")

Scenario B: Exhaust major cycle limit

Result: Reached the major cycle limit (nmajor)
Stopcode: StopCode(major=9, minor=0)
Major cycles done: 3
Total iterations: 36
Final peak: 0.022528 Jy



### Scenario C: Zero Mask (stopcode 7)

In [6]:
# Simulate mask becoming empty (no valid pixels to clean)
controller_C = IterationController(
    niter=1000,
    nmajor=10,
    threshold=0.001,
    gain=0.1,
    cyclefactor=1.5,
    cycleniter=20,
)

current_peak = 1.0
print("Scenario C: Zero mask (no valid pixels)\n")

# Create return_dict with masksum=0
result_rd = create_return_dict(current_peak, 10, masksum=0, max_psf_sidelobe=0.2)
controller_C.check_convergence(result_rd)

print(f"Result: {controller_C.stopdescription}")
print(f"Stopcode: {controller_C.stopcode}\n")

Scenario C: Zero mask (no valid pixels)

Result: Zero mask
Stopcode: StopCode(major=7, minor=0)



## Example 3: Interactive Parameter Updates

Demonstrate continuing deconvolution with updated parameters after initial convergence.

In [7]:
# Initialize controller with modest parameters
controller_interactive = IterationController(
    niter=100,
    nmajor=5,
    threshold=0.05,  # Relatively high threshold
    gain=0.1,
    cyclefactor=1.5,
    cycleniter=30,
)

print("=== First Deconvolution Run ===")
print(f"Initial parameters: niter={controller_interactive.niter}, threshold={controller_interactive.threshold} Jy\n")

current_peak = 1.0
run = 1

# First run
while controller_interactive.stopcode.major == 0:
    current_rd = create_return_dict(current_peak, 0, 100, 0.2)
    cycleniter, cyclethresh = controller_interactive.calculate_cycle_controls(current_rd)
    
    new_peak, iters_done = mock_deconvolve(current_peak, cycleniter, cyclethresh, 0.1)
    current_peak = new_peak
    
    result_rd = create_return_dict(new_peak, iters_done, 100, 0.2)
    controller_interactive.update_counts(result_rd)
    controller_interactive.check_convergence(result_rd)

print(f"First run completed: {controller_interactive.stopdescription}")
print(f"Peak residual: {current_peak:.6f} Jy")
print(f"Total iterations: {controller_interactive.total_iter_done}\n")

# User decides to continue with updated parameters
print("=== Updating Parameters and Continuing ===")
code, msg = controller_interactive.update_parameters(
    niter=200,  # Set the remaining iteration budget to 200
    threshold="10mJy",  # Lower threshold (0.01 Jy)
    nmajor=5,  # Set the remaining major-cycle budget to 5
)

if code == 0:
    print("Parameters updated successfully:")
    print(f"  niter: {controller_interactive.niter}")
    print(f"  threshold: {controller_interactive.threshold} Jy")
    print(f"  nmajor: {controller_interactive.nmajor}\n")
    
    # Reset stopcode to continue
    controller_interactive.reset_stopcode()
    
    print("Continuing deconvolution...\n")
    
    # Second run
    while controller_interactive.stopcode.major == 0:
        current_rd = create_return_dict(current_peak, 0, 100, 0.2)
        cycleniter, cyclethresh = controller_interactive.calculate_cycle_controls(current_rd)
        
        new_peak, iters_done = mock_deconvolve(current_peak, cycleniter, cyclethresh, 0.1)
        current_peak = new_peak
        
        result_rd = create_return_dict(new_peak, iters_done, 100, 0.2)
        controller_interactive.update_counts(result_rd)
        controller_interactive.check_convergence(result_rd)
    
    print(f"Second run completed: {controller_interactive.stopdescription}")
    print(f"Final peak residual: {current_peak:.6f} Jy")
    print(f"Total iterations: {controller_interactive.total_iter_done}")
    print(f"Total major cycles: {controller_interactive.major_done}")
else:
    print(f"Parameter update failed: {msg}")

=== First Deconvolution Run ===
Initial parameters: niter=100, threshold=0.05 Jy

First run completed: Reached global stopping threshold (within mask)
Peak residual: 0.047101 Jy
Total iterations: 29

=== Updating Parameters and Continuing ===
Parameters updated successfully:
  niter: 200
  threshold: 0.01 Jy
  nmajor: 5

Continuing deconvolution...

Second run completed: Reached global stopping threshold (within mask)
Final peak residual: 0.009698 Jy
Total iterations: 44
Total major cycles: 5
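
The `update_parameters` call above passes `threshold="10mJy"`, and the controller then reports `0.01 Jy`. How the module parses such strings is not shown in this notebook. The sketch below is a hypothetical converter (`to_jy` and the set of supported prefixes are assumptions) that illustrates one way a unit string could be turned into a value in Jy.

```python
import re

# Assumed SI prefixes for Jy; the real module may accept more or fewer
_JY_PREFIX = {"": 1.0, "m": 1e-3, "u": 1e-6, "n": 1e-9}
_PATTERN = re.compile(r"\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)\s*([mun]?)Jy\s*")


def to_jy(value):
    """Convert a number or a string such as '10mJy' to a float in Jy (sketch)."""
    if isinstance(value, (int, float)):
        return float(value)  # plain numbers are taken to be in Jy already
    match = _PATTERN.fullmatch(value)
    if match is None:
        raise ValueError(f"cannot parse threshold: {value!r}")
    number, prefix = match.groups()
    return float(number) * _JY_PREFIX[prefix]


print(f"{to_jy('10mJy'):g}")  # 0.01
```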


## Example 4: ReturnDict Utilities

Demonstrate working with ReturnDict objects: creating, merging, and extracting statistics.

### Creating ReturnDict for Multiple Planes

In [8]:
# Create ReturnDict with multiple time/pol/chan entries
rd_multi = ReturnDict()

# Simulate results from 3 channels, 2 polarizations
np.random.seed(42)
for chan in range(3):
    for pol in range(2):
        peak = 1.0 - 0.1 * chan - 0.05 * pol + np.random.uniform(-0.05, 0.05)
        iters = np.random.randint(80, 120)
        
        rd_multi.add({
            'peakres': peak,
            'peakres_nomask': peak * 1.1,
            'masksum': 100 - chan * 10,
            'iter_done': iters,
            'max_psf_sidelobe': 0.2,
            'niter': iters,
            'threshold': 0.0,
            'loop_gain': 0.1,
        }, time=0, pol=pol, chan=chan)

print("Created ReturnDict with entries:")
for key in rd_multi.data.keys():
    entry = rd_multi.data[key]
    print(f"  {key}: peakres={entry['peakres']:.4f}, iter_done={entry['iter_done']}, masksum={entry['masksum']}")

Created ReturnDict with entries:
  Key(time=0, pol=0, chan=0): peakres=0.9875, iter_done=108, masksum=100
  Key(time=0, pol=1, chan=0): peakres=0.9183, iter_done=87, masksum=100
  Key(time=0, pol=0, chan=1): peakres=0.9099, iter_done=118, masksum=90
  Key(time=0, pol=1, chan=1): peakres=0.8446, iter_done=102, masksum=90
  Key(time=0, pol=0, chan=2): peakres=0.7558, iter_done=103, masksum=80
  Key(time=0, pol=1, chan=2): peakres=0.7334, iter_done=119, masksum=80


### Extracting Statistics from ReturnDict

In [9]:
# Extract overall statistics
peak_residual = get_peak_residual_from_returndict(rd_multi)
total_iters = get_iterations_done_from_returndict(rd_multi)

print(f"Overall peak residual: {peak_residual:.6f} Jy")
print(f"Total iterations: {total_iters}\n")

# Extract statistics for specific channel
peak_chan1 = get_peak_residual_from_returndict(rd_multi, chan=1)
iters_chan1 = get_iterations_done_from_returndict(rd_multi, chan=1)

print(f"Channel 1 peak residual: {peak_chan1:.6f} Jy")
print(f"Channel 1 iterations: {iters_chan1}")

Overall peak residual: 0.987454 Jy
Total iterations: 637

Channel 1 peak residual: 0.909866 Jy
Channel 1 iterations: 220
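
The numbers above are consistent with the peak residual being the maximum `peakres` over the selected entries and the iteration count being the sum of `iter_done`. The sketch below reproduces them with a plain dictionary keyed by a `Key` namedtuple. It is a stand-in for illustration, not the module's `ReturnDict`; the helper names are hypothetical, and the `peakres` values are the rounded ones printed earlier.

```python
from collections import namedtuple

Key = namedtuple("Key", ["time", "pol", "chan"])

# (peakres, iter_done) per plane, copied from the printed entries
entries = {
    Key(0, 0, 0): {"peakres": 0.9875, "iter_done": 108},
    Key(0, 1, 0): {"peakres": 0.9183, "iter_done": 87},
    Key(0, 0, 1): {"peakres": 0.9099, "iter_done": 118},
    Key(0, 1, 1): {"peakres": 0.8446, "iter_done": 102},
    Key(0, 0, 2): {"peakres": 0.7558, "iter_done": 103},
    Key(0, 1, 2): {"peakres": 0.7334, "iter_done": 119},
}


def _matches(key, select):
    """True if every selected coordinate (time/pol/chan) matches the key."""
    return all(getattr(key, name) == value for name, value in select.items())


def peak_residual(entries, **select):
    """Maximum peakres over the selected planes (assumed aggregation)."""
    return max(e["peakres"] for k, e in entries.items() if _matches(k, select))


def iterations_done(entries, **select):
    """Sum of iter_done over the selected planes (assumed aggregation)."""
    return sum(e["iter_done"] for k, e in entries.items() if _matches(k, select))


print(peak_residual(entries), iterations_done(entries))                  # 0.9875 637
print(peak_residual(entries, chan=1), iterations_done(entries, chan=1))  # 0.9099 220
```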


### Merging Multiple ReturnDicts

In [10]:
# Simulate results from 3 different workers processing different channels
rd_worker1 = create_return_dict(0.95, 100, 100, 0.2, time=0, pol=0, chan=0)
rd_worker2 = create_return_dict(0.85, 110, 100, 0.2, time=0, pol=0, chan=1)
rd_worker3 = create_return_dict(0.75, 95, 100, 0.2, time=0, pol=0, chan=2)

print("Merging ReturnDicts from 3 workers...\n")

# Merge with 'latest' strategy (default)
merged_rd = merge_return_dicts([rd_worker1, rd_worker2, rd_worker3], merge_strategy='latest')

print("Merged ReturnDict entries:")
for key in merged_rd.data.keys():
    entry = merged_rd.data[key]
    print(f"  {key}: peakres={entry['peakres']:.4f}, iter_done={entry['iter_done']}")

# Extract global statistics from merged result
global_peak = get_peak_residual_from_returndict(merged_rd)
global_iters = get_iterations_done_from_returndict(merged_rd)

print(f"\nGlobal peak residual: {global_peak:.4f} Jy")
print(f"Total iterations across all workers: {global_iters}")

Merging ReturnDicts from 3 workers...

Merged ReturnDict entries:
  Key(time=0, pol=0, chan=0): peakres=0.9500, iter_done=100
  Key(time=0, pol=0, chan=1): peakres=0.8500, iter_done=110
  Key(time=0, pol=0, chan=2): peakres=0.7500, iter_done=95

Global peak residual: 0.9500 Jy
Total iterations across all workers: 305
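
The three workers above hold disjoint keys, so the example does not show what `merge_strategy='latest'` does when two inputs share a key. A plausible reading, stated here as an assumption rather than documented behavior, is that the entry from the later input wins. The sketch below illustrates that rule with plain dictionaries keyed by `(time, pol, chan)` tuples; `merge_latest` is a hypothetical name.

```python
def merge_latest(dicts):
    """Merge per-plane results; later inputs overwrite earlier ones (sketch)."""
    merged = {}
    for d in dicts:
        merged.update(d)  # on a key collision the later entry replaces the earlier
    return merged


worker1 = {(0, 0, 0): {"peakres": 0.95, "iter_done": 100}}
worker2 = {(0, 0, 1): {"peakres": 0.85, "iter_done": 110}}
# A second result for channel 0, e.g. from a re-run of that plane
worker3 = {(0, 0, 0): {"peakres": 0.90, "iter_done": 120}}

merged = merge_latest([worker1, worker2, worker3])
print(sorted(merged))                # [(0, 0, 0), (0, 0, 1)]
print(merged[(0, 0, 0)]["peakres"])  # 0.9
```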


### Using Merged Results for Convergence Check

In [11]:
# Create controller and check convergence using merged results
controller_merge = IterationController(
    niter=500,
    nmajor=10,
    threshold=0.8,  # Below the merged peak (0.95 Jy), so cleaning continues
    gain=0.1,
)

# Update counts and check convergence
controller_merge.update_counts(merged_rd)
stopcode, stopdesc = controller_merge.check_convergence(merged_rd)

print("Convergence check result:")
print(f"  Stopcode: {stopcode}")
print(f"  Description: {stopdesc}")
print(f"  Iterations remaining: {controller_merge.niter}")
print(f"  Major cycles done: {controller_merge.major_done}")

Convergence check result:
  Stopcode: StopCode(major=0, minor=0)
  Description: Continue iterations
  Iterations remaining: 195
  Major cycles done: 1
