# AoflaggerTrackingPlugin Results Analysis

This notebook analyzes the results from the AoflaggerTrackingPlugin by loading the stored context and visualizing:
- Track visibility data for each receiver
- Flags after AOFlagger processing (combined over all flagging stages; a per-stage before/after comparison is still **TODO**)
- Flagging statistics and comparisons

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
from pathlib import Path

# MuSEEK imports
from museek.enums.result_enum import ResultEnum
from museek.time_ordered_data import TimeOrderedData

# Plot defaults shared by every figure in this notebook
plt.rcParams.update({
    'figure.figsize': [12, 8],
    'font.size': 12,
})

print("Imports completed successfully!")

## 1. Load Context Data

Locate the stored context from the pipeline run. Update `context_folder` and `block_name` below to match your pipeline configuration.

In [None]:
# Update these paths based on your test_calibrator_flagging.py config.
# Both can also be overridden without editing the notebook through the
# MUSEEK_CONTEXT_FOLDER / MUSEEK_BLOCK_NAME environment variables.
context_folder = os.environ.get("MUSEEK_CONTEXT_FOLDER", "/idia/users/msantos/museek")  # From your config
block_name = os.environ.get("MUSEEK_BLOCK_NAME", "1675021905")  # From your config

# Context files for this block are looked up under <context_folder>/<block_name>/
context_path = Path(context_folder) / block_name
aoflagger_file = context_path / "aoflagger_tracking_plugin.pickle"

print(f"Looking for context in: {context_path}")
print(f"AoflaggerTrackingPlugin file: {aoflagger_file}")
print(f"File exists: {aoflagger_file.exists()}")

# List available context files (sorted so the listing is reproducible across runs)
if context_path.exists():
    context_files = sorted(context_path.glob("*.pickle"))
    print(f"\nAvailable context files:")
    for f in context_files:
        print(f"  - {f.name}")
else:
    print(f"Context path does not exist: {context_path}")

In [None]:
# Load the stored pipeline context with plain pickle.
# NOTE: pickle.load can execute arbitrary code - only open context files
# produced by your own pipeline runs.
if not aoflagger_file.exists():
    print("ERROR: Context file not found. Make sure the pipeline ran successfully with do_store_context=True")
    track_data = None
else:
    with open(aoflagger_file, 'rb') as context_fh:
        data_read = pickle.load(context_fh)

    # Each stored entry is looked up by its ResultEnum key; the payload lives in `.result`
    track_data = data_read.get(ResultEnum.TRACK_DATA).result
    # Calibrator bookkeeping, used by the plotting cells to split data per calibrator period
    calibrator_validated_periods = data_read.get(ResultEnum.CALIBRATOR_VALIDATED_PERIODS).result
    calibrator_dump_indices = data_read.get(ResultEnum.CALIBRATOR_DUMP_INDICES).result

    receiver_names = [r.name for r in track_data.receivers]
    antenna_names = [a.name for a in track_data.antennas]

    print(f"Successfully loaded track data!")
    print(f"Track data shape: {track_data.visibility.shape}")
    print(f"Receivers: {receiver_names}")
    print(f"Antennas: {antenna_names}")
    # The visibility axes are reported here as (time dump, frequency, receiver)
    print(f"Number of time dumps: {track_data.visibility.shape[0]}")
    print(f"Number of frequencies: {track_data.visibility.shape[1]}")
    print(f"Number of receivers: {track_data.visibility.shape[2]}")
    print(f"Calibrator periods: {calibrator_validated_periods}")

## 2. Load Flag Information

Load the visibilities and flags, combine all flagging stages (including AOFlagger) into a single mask, and report the flagged fraction per receiver.

In [None]:
if track_data is not None:
    # Load visibility and flag data onto the track data object
    track_data.load_visibility_flags_weights(polars='auto')

    # Merge all flag layers into a single mask.
    # NOTE(review): threshold=1 presumably means "flagged if at least one layer flags it" - confirm
    all_flags = track_data.flags.combine(threshold=1)

    print(f"Flag data loaded successfully!")
    print(f"Total number of flag types: {len(track_data.flags)}")
    flag_name_list = data_read.get(ResultEnum.FLAG_NAME_LIST).result
    print(f"Applied flags: {flag_name_list}")

    # Flagged fraction over all (time, frequency) samples, per receiver
    for recv_index, receiver in enumerate(track_data.receivers):
        recv_flags = all_flags.get(recv=recv_index).squeeze
        n_flagged = np.sum(recv_flags)
        n_total = recv_flags.size
        print(f"{receiver.name}: {n_flagged / n_total:.1%} flagged ({n_flagged} / {n_total} samples)")

## 3. Visualize Track Visibility Data

Plot the track visibility data for each receiver as time–frequency waterfall plots, split by calibrator period, with flagged (RFI) samples masked out.

In [None]:
if track_data is not None:
    # Frequency axis in MHz (left at module level: later cells may also plot against it)
    frequencies = track_data.frequencies.squeeze / 1e6  # Convert Hz to MHz

    # Display labels for the calibrator periods, matched by position to
    # `calibrator_validated_periods`; any extra period is labelled "Period <i>".
    calibrator_names = ['HydraA', 'PictorA']  # From your config

    n_receivers = len(track_data.receivers)
    n_periods = len(calibrator_validated_periods)

    if n_periods == 0:
        # plt.subplots cannot build a grid with zero columns
        print("No validated calibrator periods found - nothing to plot.")
    else:
        # squeeze=False always returns a 2-D (n_receivers, n_periods) array of Axes,
        # so axes[i_recv, i_period] is valid for every grid size, including 1x1.
        fig, axes = plt.subplots(n_receivers, n_periods,
                                 figsize=(8 * n_periods, 6 * n_receivers), squeeze=False)

        # Original dump indices of the track rows, used to select each calibrator period
        dumps = np.array(track_data._dumps_of_scan_state())

        for i_recv, receiver in enumerate(track_data.receivers):
            # Get visibility and flag data for this receiver
            vis_data = track_data.visibility.get(recv=i_recv).squeeze
            flags_data = all_flags.get(recv=i_recv).squeeze

            for i_period, period in enumerate(calibrator_validated_periods):
                ax = axes[i_recv, i_period]
                calibrator_name = calibrator_names[i_period] if i_period < len(calibrator_names) else f"Period {i_period}"

                # Keep the track rows whose original dump index belongs to this period
                period_mask = np.isin(dumps, calibrator_dump_indices[period])
                period_vis_data = vis_data[period_mask]
                period_flags_data = flags_data[period_mask]

                if period_vis_data.size == 0:
                    ax.text(0.5, 0.5, 'No data for this period',
                            ha='center', va='center', transform=ax.transAxes)
                    ax.set_title(f'{receiver.name} - {calibrator_name}')
                    continue

                # Mask flagged (RFI) samples; plot the amplitude to match the colorbar label
                masked_vis_data = np.ma.masked_array(np.abs(period_vis_data), mask=period_flags_data)
                n_samples = masked_vis_data.shape[0]

                # Transposed to (frequency, time). With origin='lower', row 0 (the first
                # channel) is drawn at the bottom, so anchor the y-extent on the first/last
                # channel rather than min/max. This stays correct for a descending band.
                im = ax.imshow(masked_vis_data.T, aspect='auto', origin='lower',
                               extent=[0, n_samples, frequencies[0], frequencies[-1]],
                               cmap='viridis', interpolation='nearest')

                flag_percentage = 100 * np.sum(period_flags_data) / period_flags_data.size

                ax.set_ylabel('Frequency [MHz]')
                ax.set_title(f'{receiver.name} - {calibrator_name} (Clean Data)\n'
                             f'{n_samples} samples, {flag_percentage:.1f}% flagged')
                fig.colorbar(im, ax=ax, label='Visibility Amplitude', fraction=0.046)

        # Set x-label for bottom row
        for ax in axes[-1]:
            ax.set_xlabel('Time Sample (within period)')

        fig.suptitle('Track Visibility Data (Flagged RFI Excluded) - Split by Calibrator Periods', fontsize=16)
        fig.tight_layout()
        plt.show()

## 3b. Visibility vs Time (Median Across Frequency)

Plot the visibility amplitude vs time using the median across frequency after flagging. This shows the time evolution of the calibrator signal strength.

In [None]:
if track_data is not None:
    # Display labels for the calibrator periods, matched by position to
    # `calibrator_validated_periods`; any extra period is labelled "Period <i>".
    calibrator_names = ['HydraA', 'PictorA']  # From your config

    n_receivers = len(track_data.receivers)
    n_periods = len(calibrator_validated_periods)

    if n_periods == 0:
        # plt.subplots cannot build a grid with zero columns
        print("No validated calibrator periods found - nothing to plot.")
    else:
        # squeeze=False always returns a 2-D (n_receivers, n_periods) array of Axes,
        # so axes[i_recv, i_period] is valid for every grid size, including 1x1.
        fig, axes = plt.subplots(n_receivers, n_periods,
                                 figsize=(8 * n_periods, 6 * n_receivers), squeeze=False)

        # Original dump indices and timestamps of the track rows
        dumps = np.array(track_data._dumps_of_scan_state())
        timestamps = track_data.timestamps.squeeze  # Unix timestamps

        for i_recv, receiver in enumerate(track_data.receivers):
            # Get visibility and flag data for this receiver
            vis_data = track_data.visibility.get(recv=i_recv).squeeze
            flags_data = all_flags.get(recv=i_recv).squeeze

            for i_period, period in enumerate(calibrator_validated_periods):
                ax = axes[i_recv, i_period]
                calibrator_name = calibrator_names[i_period] if i_period < len(calibrator_names) else f"Period {i_period}"

                # Keep the track rows whose original dump index belongs to this period
                period_mask = np.isin(dumps, calibrator_dump_indices[period])
                period_vis_data = vis_data[period_mask]
                period_flags_data = flags_data[period_mask]
                period_timestamps = timestamps[period_mask]

                if period_vis_data.size == 0:
                    ax.text(0.5, 0.5, 'No data for this period',
                            ha='center', va='center', transform=ax.transAxes)
                    ax.set_title(f'{receiver.name} - {calibrator_name}')
                    continue

                # Median amplitude across frequency per time sample, ignoring flagged data.
                # A fully flagged time sample yields a masked median and is not drawn.
                masked_vis_data = np.ma.masked_array(np.abs(period_vis_data), mask=period_flags_data)
                vis_median_vs_time = np.ma.median(masked_vis_data, axis=1)

                # Minutes since the first sample of this period
                time_minutes = (period_timestamps - period_timestamps[0]) / 60.0

                # Plot only points (no lines)
                ax.plot(time_minutes, vis_median_vs_time, 'o', color='blue',
                        markersize=4, alpha=0.7)

                ax.set_ylabel('Visibility Amplitude\n(Median across frequency)')
                ax.set_title(f'{receiver.name} - {calibrator_name}\n'
                             f'({period_vis_data.shape[0]} samples, '
                             f'{time_minutes[-1]:.1f} min duration)')
                ax.grid(True, alpha=0.3)
                # Log scale for better dynamic range
                ax.set_yscale('log')

        # Set x-label for bottom row
        for ax in axes[-1]:
            ax.set_xlabel('Time (minutes from period start)')

        fig.suptitle('Track Visibility vs Time (Flagged RFI Excluded) - Split by Calibrator Periods', fontsize=16)
        fig.tight_layout()
        plt.show()

## 3c. Elevation vs Time

Plot elevation vs time for each dish, using only the tracking dumps and splitting the data by calibrator period.

In [None]:
if track_data is not None:
    # Display labels for the calibrator periods, matched by position to
    # `calibrator_validated_periods`; any extra period is labelled "Period <i>".
    calibrator_names = ['HydraA', 'PictorA']  # From your config

    n_antennas = len(track_data.antennas)
    n_periods = len(calibrator_validated_periods)

    if n_periods == 0:
        # plt.subplots cannot build a grid with zero columns
        print("No validated calibrator periods found - nothing to plot.")
    else:
        # squeeze=False always returns a 2-D (n_antennas, n_periods) array of Axes,
        # so axes[i_ant, i_period] is valid for every grid size, including 1x1.
        fig, axes = plt.subplots(n_antennas, n_periods,
                                 figsize=(8 * n_periods, 6 * n_antennas), squeeze=False)

        # Original dump indices and timestamps of the track rows
        dumps = np.array(track_data._dumps_of_scan_state())
        timestamps = track_data.timestamps.squeeze  # Unix timestamps

        for i_ant, antenna in enumerate(track_data.antennas):
            # Elevation is selected with the antenna index on the `recv` axis.
            # NOTE(review): assumes that axis follows track_data.antennas order - confirm
            elevation_data = track_data.elevation.get(recv=i_ant).squeeze

            for i_period, period in enumerate(calibrator_validated_periods):
                ax = axes[i_ant, i_period]
                calibrator_name = calibrator_names[i_period] if i_period < len(calibrator_names) else f"Period {i_period}"

                # Keep the track rows whose original dump index belongs to this period
                period_mask = np.isin(dumps, calibrator_dump_indices[period])
                period_elevation = elevation_data[period_mask]
                period_timestamps = timestamps[period_mask]

                if period_elevation.size == 0:
                    ax.text(0.5, 0.5, 'No data for this period',
                            ha='center', va='center', transform=ax.transAxes)
                    ax.set_title(f'{antenna.name} - {calibrator_name}')
                    continue

                # Minutes since the first sample of this period
                time_minutes = (period_timestamps - period_timestamps[0]) / 60.0

                # Plot only points (no lines)
                ax.plot(time_minutes, period_elevation, 'o', color='green',
                        markersize=4, alpha=0.7)

                ax.set_ylabel('Elevation [degrees]')
                ax.set_title(f'{antenna.name} - {calibrator_name}\n'
                             f'({period_elevation.shape[0]} samples, '
                             f'{time_minutes[-1]:.1f} min duration)')
                ax.grid(True, alpha=0.3)

        # Set x-label for bottom row
        for ax in axes[-1]:
            ax.set_xlabel('Time (minutes from period start)')

        fig.suptitle('Antenna Elevation vs Time - Split by Calibrator Periods', fontsize=16)
        fig.tight_layout()
        plt.show()

## 4. Frequency-Domain Analysis

Analyze flagging patterns in the frequency domain

In [None]:
if track_data is not None:
    # Build the frequency axis here rather than relying on a `frequencies`
    # variable left behind by an earlier cell (hidden notebook state).
    frequencies = track_data.frequencies.squeeze / 1e6  # Convert Hz to MHz

    # Both panels share the frequency axis
    fig, axes = plt.subplots(2, 1, figsize=(15, 10), sharex=True)

    for i_recv, receiver in enumerate(track_data.receivers):
        flags = all_flags.get(recv=i_recv).squeeze
        vis_data = track_data.visibility.get(recv=i_recv).squeeze

        # Fraction of time samples flagged in each frequency channel
        flag_fraction_per_freq = np.sum(flags, axis=0) / flags.shape[0]

        # Mean visibility amplitude per channel over unflagged samples only
        masked_vis = np.ma.masked_array(np.abs(vis_data), mask=flags)
        mean_vis_per_freq = np.ma.mean(masked_vis, axis=0)

        axes[0].plot(frequencies, flag_fraction_per_freq, 'o-',
                     label=f'{receiver.name}', markersize=3)
        axes[1].plot(frequencies, mean_vis_per_freq, 'o-',
                     label=f'{receiver.name}', markersize=3)

    axes[0].set_ylabel('Flag Fraction')
    axes[0].set_title('Flag Fraction per Frequency Channel')
    axes[0].grid(True, alpha=0.3)
    axes[0].legend()
    axes[0].set_ylim(0, 1)

    axes[1].set_xlabel('Frequency [MHz]')
    axes[1].set_ylabel('Mean Visibility Amplitude')
    axes[1].set_title('Mean Visibility Amplitude per Frequency Channel (Unflagged Data)')
    axes[1].grid(True, alpha=0.3)
    axes[1].legend()
    axes[1].set_yscale('log')

    fig.tight_layout()
    plt.show()

## 5. Time-Domain Analysis

Analyze flagging patterns in the time domain

In [None]:
if track_data is not None:
    # Original dump index of every track row, used to keep calibrator rows only
    dumps = np.array(track_data._dumps_of_scan_state())

    # Flatten the dump indices of every validated calibrator period into one list
    all_calibrator_dump_indices = [
        dump_index
        for period in calibrator_validated_periods
        for dump_index in calibrator_dump_indices[period]
    ]
    calibrator_mask = np.isin(dumps, all_calibrator_dump_indices)

    print(f"Total time samples in track data: {len(dumps)}")
    print(f"Calibrator time samples: {np.sum(calibrator_mask)}")
    print(f"Using only calibrator periods for time-domain analysis")

    fig, axes = plt.subplots(2, 1, figsize=(15, 10))
    ax_flags, ax_vis = axes

    for recv_index, receiver in enumerate(track_data.receivers):
        # This receiver's flags / visibilities restricted to calibrator rows
        cal_flags = all_flags.get(recv=recv_index).squeeze[calibrator_mask]
        cal_vis = track_data.visibility.get(recv=recv_index).squeeze[calibrator_mask]

        # Per time sample: fraction of flagged channels, and mean unflagged amplitude
        flag_fraction_per_time = np.sum(cal_flags, axis=1) / cal_flags.shape[1]
        mean_vis_per_time = np.ma.masked_array(np.abs(cal_vis), mask=cal_flags).mean(axis=1)

        # Running index over calibrator samples; periods are simply concatenated
        sample_index = np.arange(flag_fraction_per_time.size)

        ax_flags.plot(sample_index, flag_fraction_per_time, 'o-',
                      label=f'{receiver.name}', markersize=3)
        ax_vis.plot(sample_index, mean_vis_per_time, 'o-',
                    label=f'{receiver.name}', markersize=3)

    ax_flags.set(ylabel='Flag Fraction',
                 title='Flag Fraction per Time Sample (Calibrator Periods Only)')
    ax_flags.grid(True, alpha=0.3)
    ax_flags.legend()
    ax_flags.set_ylim(0, 1)

    ax_vis.set(xlabel='Time Sample (Calibrator Periods)',
               ylabel='Mean Visibility Amplitude',
               title='Mean Visibility Amplitude per Time Sample (Calibrator Periods, Unflagged Data)')
    ax_vis.grid(True, alpha=0.3)
    ax_vis.legend()
    ax_vis.set_yscale('log')

    fig.tight_layout()
    plt.show()