# Review and Filter Masks

This notebook lets you visually review each mask and mark poor-quality masks for exclusion from further analysis.

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.widgets import Button
from PIL import Image

## Configuration

In [None]:
# Per-camera layout: all inputs live under a folder named after the camera ID.
#   <camera_id>/images               -> images_dir
#   <camera_id>/masks                -> masks_dir
#   <camera_id>/images_and_data.csv  -> csv_path
camera_id = 'WI_East_Branch_Pecatonica_River_near_Blanchardville_Bullet'
images_dir = camera_id + '/images'
masks_dir = camera_id + '/masks'
csv_path = camera_id + '/images_and_data.csv'

print("Loading data from: " + csv_path)

## Load Data

In [None]:
# Read the per-image table for this camera.
df = pd.read_csv(csv_path)

# A table without a 'quality_flag' column gets one with every row set to
# 'good'; if the column is already present its stored values are kept as-is.
if 'quality_flag' in df.columns:
    print("'quality_flag' column already exists")
else:
    df['quality_flag'] = 'good'
    print("Added 'quality_flag' column (all set to 'good' by default)")

record_count = len(df)
print(f"\nLoaded {record_count} records")
print("Current quality flags:")
print(df['quality_flag'].value_counts())

# Preview only the columns that matter for the review.
preview_columns = ['image_names', 'mask_filename', '00065', 'quality_flag']
display(df[preview_columns].head())

## Helper Functions

In [None]:
def load_mask(mask_path):
    """Read a mask array from a .npy file and return it as a boolean array.

    A 3-D array has its length-1 axes removed with ``squeeze()`` before the
    cast; arrays of any other rank are cast unchanged.
    """
    raw = np.load(mask_path)
    flattened = raw.squeeze() if raw.ndim == 3 else raw
    return flattened.astype(bool)


def overlay_mask_on_image(image, mask, color=(0, 255, 255), alpha=0.4):
    """
    Blend a solid color into the masked pixels of an image.

    Only pixels selected by the mask are blended; every other pixel is copied
    through unchanged. The input image is never modified.

    Args:
        image: PIL Image or numpy array. When ``color`` is a 3-value RGB
            tuple, non-RGB input is normalised to three channels first:
            PIL images via ``convert('RGB')``, 2-D / single-channel arrays by
            repeating the channel, and arrays with more than three channels
            (e.g. RGBA) by keeping the first three.
        mask: Binary mask (boolean or 0/1) with the image's height and width.
            Length-1 axes of a mask with more than two dimensions are removed.
        color: RGB color tuple for the mask
        alpha: Transparency of the mask overlay (0 keeps the image, 1 paints
            the solid color)

    Returns:
        uint8 numpy array with mask overlay
    """
    color_arr = np.asarray(color)
    # Channel normalisation only applies to the documented RGB-tuple case so
    # that other color/image combinations behave exactly as before.
    wants_rgb = color_arr.shape == (3,)

    # ndarray is checked first so plain arrays never touch the PIL branch.
    if isinstance(image, np.ndarray):
        img_array = image
    elif isinstance(image, Image.Image):
        if wants_rgb and image.mode != 'RGB':
            image = image.convert('RGB')
        img_array = np.array(image)
    else:
        img_array = np.asarray(image)

    if wants_rgb:
        if img_array.ndim == 2:
            # Grayscale: replicate the single channel into R, G and B
            img_array = np.repeat(img_array[:, :, np.newaxis], 3, axis=2)
        elif img_array.ndim == 3 and img_array.shape[2] == 1:
            img_array = np.repeat(img_array, 3, axis=2)
        elif img_array.ndim == 3 and img_array.shape[2] > 3:
            # Drop alpha (or any further channels); only RGB is blended
            img_array = img_array[:, :, :3]

    # Ensure mask is 2D boolean
    mask_bool = np.asarray(mask)
    if mask_bool.ndim > 2:
        mask_bool = mask_bool.squeeze()
    mask_bool = mask_bool.astype(bool)

    # astype() returns a fresh array, so the caller's image stays untouched
    result = img_array.astype(np.uint8)

    # Cast the color to the image dtype first (as a direct assignment into the
    # image would), then blend just the masked pixels in floating point.
    fill = color_arr.astype(img_array.dtype)
    blended = alpha * fill + (1 - alpha) * img_array[mask_bool]
    result[mask_bool] = blended.astype(np.uint8)

    return result

## Review All Masks (Grid View)

View all masks in a grid to get an overview and identify problematic ones.

In [None]:
# Create a grid view of all masks
n_images = len(df)
n_cols = 6
# Ceiling division; keep at least one row so an empty table still yields a figure
n_rows = max(1, (n_images + n_cols - 1) // n_cols)

# squeeze=False always returns a 2-D array of Axes, so flattening works for
# any number of images (including a single image or a single row).
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, n_rows * 3), squeeze=False)
axes = axes.flatten()

print("Loading images and masks for grid view...")

# Axes are addressed by position; the DataFrame index label (idx) is what gets
# displayed, because rows are later flagged by label via df.loc[...].
for pos, (idx, row) in enumerate(df.iterrows()):
    ax = axes[pos]

    # Load image and mask
    image_path = os.path.join(images_dir, row['image_names'])
    mask_path = os.path.join(masks_dir, row['mask_filename'])

    if os.path.exists(image_path) and os.path.exists(mask_path):
        # The context manager releases the file handle once the pixels have
        # been copied into an array; RGB conversion lets grayscale, palette
        # and RGBA files take the same 3-channel overlay path.
        with Image.open(image_path) as img_file:
            image = np.array(img_file.convert('RGB'))
        mask = load_mask(mask_path)

        # Create overlay
        overlay = overlay_mask_on_image(image, mask, color=(0, 255, 255), alpha=0.5)

        # Display
        ax.imshow(overlay)

        # Color-code title based on quality flag
        quality = row['quality_flag']
        title_color = 'green' if quality == 'good' else 'red'
        ax.set_title(f"#{idx} - {quality}\nGH: {row['00065']:.2f} ft",
                     fontsize=8, color=title_color, weight='bold')
    else:
        ax.text(0.5, 0.5, f'Missing\n#{idx}', ha='center', va='center')
        ax.set_title(f"#{idx} - MISSING", fontsize=8, color='orange')

    ax.axis('off')

# Hide extra subplots
for unused_ax in axes[n_images:]:
    unused_ax.axis('off')

plt.tight_layout()
plt.suptitle('All Masks Overview (Green=Good, Red=Bad)', fontsize=16, weight='bold', y=1.001)
plt.show()

print(f"\nDisplayed {n_images} image/mask pairs")
print("\nNote the image numbers (#) of any bad masks you see.")

## Mark Bad Masks

Enter the indices of masks you want to mark as 'bad' based on the grid view above.

In [None]:
# Enter the indices of bad masks here (from the grid view above)
# Example: bad_indices = [5, 12, 23, 45]
bad_indices = []  # UPDATE THIS LIST

if bad_indices:
    # Split the request into labels that exist in the table and ones that do
    # not, so a mistyped index is reported instead of being passed to df.loc.
    valid_indices = [i for i in bad_indices if i in df.index]
    unknown_indices = [i for i in bad_indices if i not in df.index]
    if unknown_indices:
        print(f"WARNING: ignoring indices not found in the data: {unknown_indices}")

    # Mark the specified indices as bad
    if valid_indices:
        df.loc[valid_indices, 'quality_flag'] = 'bad'
    print(f"Marked {len(valid_indices)} masks as 'bad': {valid_indices}")

    print("\nUpdated quality flags:")
    print(df['quality_flag'].value_counts())
else:
    # Flags loaded from the CSV may already contain 'bad' entries, so only
    # state that nothing was changed by this cell.
    print("No bad indices specified. Existing quality flags are unchanged.")

## Detailed Review (Individual Images)

Review individual images in detail. You can loop through specific indices or all images.

In [None]:
# Review specific indices in detail
# Leave empty to review all, or specify indices like [0, 5, 10]
review_indices = []  # UPDATE THIS or leave empty for all

if not review_indices:
    review_indices = df.index.tolist()

for idx in review_indices:
    # Indices are DataFrame index labels: the same numbers shown as "#n" in
    # the grid view and accepted by the label-based bad_indices list.
    if idx not in df.index:
        print(f"\nSkipping #{idx} - not found in the data")
        continue
    row = df.loc[idx]

    # Load image and mask
    image_path = os.path.join(images_dir, row['image_names'])
    mask_path = os.path.join(masks_dir, row['mask_filename'])

    if not (os.path.exists(image_path) and os.path.exists(mask_path)):
        print(f"\nSkipping #{idx} - files not found")
        continue

    # Read the pixels inside a context manager so the file handle is released
    # right away; RGB conversion keeps all three panels on a 3-channel array.
    with Image.open(image_path) as img_file:
        image = np.array(img_file.convert('RGB'))
    mask = load_mask(mask_path)
    overlay = overlay_mask_on_image(image, mask, color=(0, 255, 255), alpha=0.4)

    # Create figure with three panels
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

    # Original image
    ax1.imshow(image)
    ax1.set_title('Original Image', fontsize=12, weight='bold')
    ax1.axis('off')

    # Mask only
    ax2.imshow(mask, cmap='gray')
    ax2.set_title('Mask', fontsize=12, weight='bold')
    ax2.axis('off')

    # Overlay
    ax3.imshow(overlay)
    ax3.set_title('Overlay', fontsize=12, weight='bold')
    ax3.axis('off')

    # Overall title with metadata
    quality = row['quality_flag']
    title_color = 'green' if quality == 'good' else 'red'
    fig.suptitle(
        # str() keeps the title working if the flag is not a string (e.g. NaN)
        f"Image #{idx} - Quality: {str(quality).upper()}\n"
        f"Gage Height: {row['00065']:.2f} ft | Discharge: {row['00060']:.0f} cfs\n"
        f"{row['image_names']}",
        fontsize=13, weight='bold', color=title_color
    )

    plt.tight_layout()
    plt.show()
    # One figure is created per image; close it so figures do not accumulate
    # when the whole table is reviewed.
    plt.close(fig)

    print(f"\nImage #{idx} - Current quality: {quality}")
    print(f"To mark as bad, add {idx} to the bad_indices list in the cell above.")
    print("-" * 80)

## Manual Quality Flag Updates

You can manually update individual quality flags here.

In [None]:
# Optional manual overrides: uncomment and adapt one of the lines below.
#
# Single row:
# df.loc[5, 'quality_flag'] = 'bad'
# df.loc[12, 'quality_flag'] = 'good'
#
# Several rows in one go:
# df.loc[[5, 12, 23], 'quality_flag'] = 'bad'

# Report the flag distribution and the index labels currently marked 'bad'.
flag_counts = df['quality_flag'].value_counts()
bad_labels = df.index[df['quality_flag'] == 'bad'].tolist()

print("Current quality flag distribution:")
print(flag_counts)
print(f"\nBad masks: {bad_labels}")

## View All Bad Masks

Display all masks currently marked as 'bad' for confirmation.

In [None]:
# Rows currently flagged 'bad', shown in a grid of up to four columns with a
# red overlay for confirmation.
bad_df = df[df['quality_flag'] == 'bad']

if len(bad_df) == 0:
    print("No masks marked as 'bad'")
else:
    print(f"Found {len(bad_df)} masks marked as 'bad':\n")

    n_bad = len(bad_df)
    n_cols = min(4, n_bad)
    n_rows = (n_bad + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array of Axes, so a single bad mask
    # needs no special-casing before flattening.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, n_rows * 4), squeeze=False)
    axes = axes.flatten()

    for i, (idx, row) in enumerate(bad_df.iterrows()):
        ax = axes[i]

        # Load image and mask
        image_path = os.path.join(images_dir, row['image_names'])
        mask_path = os.path.join(masks_dir, row['mask_filename'])

        if os.path.exists(image_path) and os.path.exists(mask_path):
            # Context manager releases the file handle once the pixels are in
            # an array; RGB conversion keeps non-RGB files on the 3-channel
            # overlay path.
            with Image.open(image_path) as img_file:
                image = np.array(img_file.convert('RGB'))
            mask = load_mask(mask_path)
            overlay = overlay_mask_on_image(image, mask, color=(255, 0, 0), alpha=0.5)  # Red for bad

            ax.imshow(overlay)
            ax.set_title(f"#{idx}\nGH: {row['00065']:.2f} ft", fontsize=10, color='red', weight='bold')
        else:
            ax.text(0.5, 0.5, f'Missing #{idx}', ha='center', va='center')

        ax.axis('off')

    # Hide extra subplots
    for unused_ax in axes[n_bad:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.suptitle('All Masks Marked as BAD', fontsize=16, weight='bold', color='red', y=1.001)
    plt.show()

    print("\nBad mask indices:", bad_df.index.tolist())
    display(bad_df[['image_names', '00065', '00060', 'quality_flag']])

## Quality Flag Statistics

In [None]:
print("="*60)
print("QUALITY FLAG STATISTICS")
print("="*60)

quality_counts = df['quality_flag'].value_counts()
total_count = len(df)
good_count = quality_counts.get('good', 0)
bad_count = quality_counts.get('bad', 0)

# Guard the percentages so an empty table reports 0.0% instead of raising
# ZeroDivisionError.
good_pct = good_count / total_count * 100 if total_count else 0.0
bad_pct = bad_count / total_count * 100 if total_count else 0.0

print(f"\nTotal images: {total_count}")
print(f"Good masks: {good_count} ({good_pct:.1f}%)")
print(f"Bad masks: {bad_count} ({bad_pct:.1f}%)")

if bad_count > 0:
    # Kept in its own name so the user-edited bad_indices list from the
    # marking cell is not overwritten by this report.
    flagged_bad = df[df['quality_flag'] == 'bad']
    print(f"\nBad mask indices: {flagged_bad.index.tolist()}")

    # Show elevation range of bad masks (values come from the '00065' column)
    bad_elevations = flagged_bad['00065']
    print(f"\nElevation range of bad masks: {bad_elevations.min():.2f} - {bad_elevations.max():.2f} ft")

print("="*60)

## Save Updated CSV

Save the quality flags back to the CSV file.

In [None]:
# Save the updated dataframe back to CSV.
# csv_path is also the file the data was loaded from, so write to a sibling
# temporary file first and swap it into place with os.replace(); an
# interrupted write then leaves the existing CSV intact instead of truncated.
tmp_csv_path = f"{csv_path}.tmp"
df.to_csv(tmp_csv_path, index=False)
os.replace(tmp_csv_path, csv_path)

print(f"Saved updated quality flags to: {csv_path}")
print("\nFinal quality flag counts:")
print(df['quality_flag'].value_counts())

if df['quality_flag'].eq('bad').any():
    print("\nNote: Notebook 03 will automatically exclude masks marked as 'bad'.")
else:
    print("\nAll masks marked as 'good' - no filtering will occur in notebook 03.")

## Summary

In [None]:
# Closing report: camera, per-flag totals and where the flags were written.
banner = "=" * 70
flags = df['quality_flag']

summary_lines = [
    banner,
    "MASK REVIEW COMPLETE",
    banner,
    f"Camera: {camera_id}",
    f"Total masks reviewed: {len(df)}",
    f"Good masks: {flags.eq('good').sum()}",
    f"Bad masks: {flags.eq('bad').sum()}",
    f"\nUpdated CSV: {csv_path}",
    "\nYou can now proceed to notebook 03 for elevation map creation.",
    "Bad masks will be automatically filtered out.",
    banner,
]
print("\n".join(summary_lines))