### Running Anomaly Detection with a Trained `ohana` Model

This notebook demonstrates how to load a pre-trained 3D U-Net model and use it to find anomalies in a sample H2RG exposure file. We will follow the core logic found in the `ohana.predict.predictor` module.

We will perform the following steps:
1.  **Set up paths** to the model, configuration, and data.
2.  **Load the `UNet3D` model** and its trained weights.
3.  **Load and preprocess** the exposure data.
4.  **Run patch-based inference** to generate a full prediction mask.
5.  **Extract anomaly locations** from the mask.
6.  **Visualize the results.**

##### Step 1: Imports and Configuration

First, let's import the necessary modules from the `ohana` package and other libraries. We also define the paths to our trained model, the configuration file, and the exposure data we want to analyze.

In [3]:
import torch
import yaml
import numpy as np
import os
from collections import OrderedDict
from tqdm import tqdm
from scipy.ndimage import label, center_of_mass
import matplotlib.pyplot as plt
import sys
import json

In [5]:
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))

src_path = os.path.join(project_root, 'src')

if src_path not in sys.path:
    sys.path.insert(0, src_path)

In [None]:
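# torch.set_num_threads applies immediately. The OMP/MKL variables are typically read when
# those libraries initialize, so set them before importing numpy/torch if they must take effect.
n_threads = os.cpu_count() or 8
print(f"Setting PyTorch to use {n_threads} threads.")
torch.set_num_threads(n_threads)
os.environ["OMP_NUM_THREADS"] = str(n_threads)
os.environ["MKL_NUM_THREADS"] = str(n_threads)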

In [None]:
sys.path.insert(0, '../')

In [None]:
# Import the necessary classes from your ohana package
# !NOTE: Make sure the 'ohana' directory is in your Python path

from ohana.models.unet_3d import UNet3D
from ohana.predict.predictor import Predictor
from ohana.preprocessing.data_loader import DataLoader
from ohana.preprocessing.preprocessor import Preprocessor
from ohana.visualization.plotter import ResultVisualizer

In [None]:
""" Configuration """
# !NOTE: Replace these with the actual paths to your files.

# Path to the trained model
MODEL_PATH = "../trained_models/old_best_model_unet3d.pth"

# Path to the config file that was used for model training
CONFIG_PATH = "../configs/creator_config.yaml"

# Path to the exposure you want to run the predictions on
EXPOSURE_PATH = "/Volumes/jwst/ilongo/raw_data/18220_Euclid_SCA/ap30_100k_0p8m0p3_fullnoi_E001_18220.fits"

# Path to where the processed exposure is saved to (MUST BE .NPY)
PROCESSED_EXPOSURE_FILE = 'processed_ap30_100k_0p8m0p3_fullnoi_E001_18220.npy'

# Directory where model predictions will be stored
OUTPUT_DIR = "prediction_outputs"

# Where processed exposure will be saved
PROCESSED_DATA_PATH = os.path.join(OUTPUT_DIR, PROCESSED_EXPOSURE_FILE)

# Where the prediction mask will be saved
MASK_PATH = os.path.join(OUTPUT_DIR, 'prediction_mask.npy')

# Where detections will be saved
DETECTIONS_PATH = os.path.join(OUTPUT_DIR, 'detections.json')

In [None]:
# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

##### Step 2: Load the Trained Model

Next, we build the `UNet3D` architecture and load the saved weights from the `.pth` file. Checkpoints saved from a model wrapped in `nn.DataParallel` (multi-GPU training) prefix every state-dict key with `module.`, so the code strips that prefix before loading. This logic comes directly from the `Predictor` class.

In [None]:
# Load configuration
with open(CONFIG_PATH, 'r') as f:
    config = yaml.safe_load(f)
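
The rest of the notebook reads four keys from this configuration: `num_classes`, `patch_size`, `overlap`, and `class_map`. A quick sanity check before building the model can catch a mismatched config early. The sketch below is only an illustration: `check_config` is a hypothetical helper that is not part of `ohana`, and the example values and class names are made up.

```python
def check_config(config):
    """Return a list of problems found in a config dict (hypothetical helper)."""
    problems = []
    # Keys the notebook indexes directly; a missing one would raise KeyError later.
    for key in ("num_classes", "patch_size", "overlap"):
        if key not in config:
            problems.append(f"missing key: {key}")
    if problems:
        return problems
    patch_size = tuple(config["patch_size"])
    if len(patch_size) < 2:
        problems.append("patch_size needs at least (height, width)")
    elif config["overlap"] >= min(patch_size[0], patch_size[1]):
        # step = patch_size - overlap must stay positive for the sliding window
        problems.append("overlap must be smaller than the patch size")
    # class_map is read with .get(), so it is optional, but without it
    # Step 5 has no classes to iterate over and reports no detections.
    if not config.get("class_map"):
        problems.append("class_map is empty: no detections will be extracted")
    return problems


# Illustrative values only; the real ones come from the training config.
example = {"num_classes": 3, "patch_size": [256, 256], "overlap": 32,
           "class_map": {"background": 0, "cosmic_ray": 1, "snowball": 2}}
print(check_config(example))             # []
print(check_config({"num_classes": 3}))  # two "missing key" messages
```

An empty list means the keys this notebook relies on are present and the sliding-window step will be positive.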

In [None]:
# Initialize the model architecture
model = UNet3D(n_channels=1, n_classes=config['num_classes'])

# Load trained weights, stripping the 'module.' prefix that nn.DataParallel adds to keys
print(f"Loading model from: {MODEL_PATH}")
state_dict = torch.load(MODEL_PATH, map_location=device)

prefix = 'module.'
if next(iter(state_dict)).startswith(prefix):
    state_dict = OrderedDict((k[len(prefix):], v) for k, v in state_dict.items())
model.load_state_dict(state_dict)

model.to(device)
model.eval() # Set model to evaluation mode
print("Model loaded successfully.")
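
To make the `module.` prefix handling concrete, here is a minimal sketch that runs without PyTorch. It uses plain placeholder values instead of weight tensors, and the key names are hypothetical. Unlike the cell above, which decides based on the first key only, this variant checks every key individually.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that have it."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# Placeholder values stand in for weight tensors; the key names are hypothetical.
wrapped = OrderedDict([("module.inc.weight", 1), ("module.inc.bias", 2)])
plain = strip_prefix(wrapped)
print(list(plain))  # ['inc.weight', 'inc.bias']
```

Keys without the prefix pass through unchanged, so the same helper is safe to apply to checkpoints saved from a single-GPU run.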

##### Step 3: Load and Preprocess the Exposure Data

We'll use the `DataLoader` to load the raw exposure cube and the `Preprocessor` to clean it and create the difference-image cube, which is the actual input to our model.

In [None]:
# Initialize our helper classes
data_loader = DataLoader()
preprocessor = Preprocessor()

In [None]:
# Load the raw data
raw_exposure = data_loader.load_exposure(EXPOSURE_PATH)

In [None]:
# Process the data. This will save the result to PROCESSED_DATA_PATH.
# If the file already exists, it will load it from the cache.
processed_cube = preprocessor.process_exposure(
    raw_exposure_cube=raw_exposure,
    save_path=PROCESSED_DATA_PATH
)
print(f"Preprocessing complete. Cube shape: {processed_cube.shape}")
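
The two ideas in this step, building a difference-image cube and caching the result as a `.npy` file, can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the `Preprocessor` implementation: it assumes the difference cube is a plain consecutive-frame difference along the time axis (the real class may also clean or correct the data), and `difference_cube` and `load_or_compute` are hypothetical names.

```python
import os
import tempfile

import numpy as np


def difference_cube(raw_cube):
    """Consecutive-frame differences along the time axis: (T, H, W) -> (T-1, H, W)."""
    # Cast to float first so unsigned detector counts cannot wrap around on subtraction.
    return np.diff(raw_cube.astype(np.float32), axis=0)


def load_or_compute(save_path, compute):
    """Load a cached .npy file if it exists, else compute, save, and return."""
    if os.path.exists(save_path):
        return np.load(save_path)
    result = compute()
    np.save(save_path, result)
    return result


# Toy up-the-ramp cube: 5 frames of 4x4 pixels, each frame 10 counts higher.
ramp = np.arange(5, dtype=np.uint16)[:, None, None] * 10
raw = ramp + np.zeros((5, 4, 4), dtype=np.uint16)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "processed_demo.npy")
    first = load_or_compute(path, lambda: difference_cube(raw))   # computed and saved
    second = load_or_compute(path, lambda: difference_cube(raw))  # loaded from cache
print(first.shape)  # (4, 4, 4)
```

Differencing removes one frame from the time axis, which is why the processed cube can have a different shape from the raw exposure.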

##### Step 4: Run Patch-Based Inference

The model operates on small, overlapping spatial patches of the data cube, each spanning the full time axis. We extract these patches, run the model on them in batches, and stitch the per-patch predictions back into a single full-size 2D mask, keeping the element-wise maximum class index wherever patches overlap. This process is identical to the `predict` method in the `Predictor` class.

In [None]:
BATCH_SIZE = 8

In [None]:
# --- Patch Extraction ---
print("Extracting all patches...")
patch_size = tuple(config['patch_size'])
overlap = config['overlap']
step_h, step_w = patch_size[0] - overlap, patch_size[1] - overlap
_, H, W = processed_cube.shape
prediction_mask = np.zeros((H, W), dtype=np.uint8)

In [None]:
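# Note: if (H - patch_size[0]) is not a multiple of step_h (likewise for W),
# the trailing rows/columns are not covered and stay 0 in the mask.
patches_with_coords = []
for r in range(0, H - patch_size[0] + 1, step_h):
    for c in range(0, W - patch_size[1] + 1, step_w):
        patch = processed_cube[:, r:r+patch_size[0], c:c+patch_size[1]]
        patches_with_coords.append((patch, (r, c)))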
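
The sliding-window arithmetic is easy to check in isolation. The sketch below applies the same `range` rule as the loop above along a single axis and reports how many trailing pixels no patch reaches. The helper names and the numbers are illustrative and are not taken from the real config.

```python
def patch_origins(length, patch, overlap):
    """Top-left offsets of a sliding window along one axis."""
    step = patch - overlap
    return list(range(0, length - patch + 1, step))


def uncovered(length, patch, overlap):
    """Number of trailing pixels along one axis that no patch reaches."""
    origins = patch_origins(length, patch, overlap)
    return length - (origins[-1] + patch) if origins else length


# Illustrative numbers, not taken from the real config.
print(len(patch_origins(2048, 256, 32)))  # 9
print(uncovered(2048, 256, 32))           # 0
print(patch_origins(100, 32, 8))          # [0, 24, 48]
print(uncovered(100, 32, 8))              # 20
```

In the first case the patches tile the axis exactly. In the second, the last 20 pixels are never predicted, so it is worth checking the image size against `patch_size` and `overlap` before trusting the edges of the mask.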

In [None]:
print(f"Running inference with batch size {BATCH_SIZE}...")
for i in tqdm(range(0, len(patches_with_coords), BATCH_SIZE)):
    batch_patches_with_coords = patches_with_coords[i:i+BATCH_SIZE]
    batch_patch_data = [item[0] for item in batch_patches_with_coords]
    
    input_tensor = torch.from_numpy(np.stack(batch_patch_data)).float().unsqueeze(1).to(device)

    # Min-max normalize each patch independently to [0, 1]
    b, c_in, t, h, w = input_tensor.shape
    tensor_flat = input_tensor.reshape(b, -1)
    min_val, max_val = tensor_flat.min(dim=1, keepdim=True)[0], tensor_flat.max(dim=1, keepdim=True)[0]
    normalized_tensor = (tensor_flat - min_val) / (max_val - min_val + 1e-6)
    normalized_tensor = normalized_tensor.view(b, c_in, t, h, w)

    # Predict: take the logits at the central time slice, then the argmax over classes
    with torch.no_grad():
        logits = model(normalized_tensor)
        central_logits = logits[:, :, logits.shape[2] // 2, :, :]
        pred_batch_mask = torch.argmax(central_logits, dim=1).cpu().numpy()
    
    # Stitch results into the main mask; overlapping pixels keep the larger class index
    for j, (r, c) in enumerate([item[1] for item in batch_patches_with_coords]):
        pred_patch = pred_batch_mask[j]
        ph, pw = pred_patch.shape
        prediction_mask[r:r+ph, c:c+pw] = np.maximum(prediction_mask[r:r+ph, c:c+pw], pred_patch)

# --- Save the final mask ---
np.save(MASK_PATH, prediction_mask)
print(f"Inference complete. Prediction mask saved to: {MASK_PATH}")
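
Two steps of the loop above are worth seeing on toy data: the per-patch min-max normalization and the `np.maximum` stitching. The sketch below reproduces both in plain NumPy so it runs without a model. The function names are hypothetical and the arrays are made up.

```python
import numpy as np


def minmax_per_patch(batch, eps=1e-6):
    """Scale each patch in a (B, ...) batch to roughly [0, 1] using its own min and max."""
    flat = batch.reshape(batch.shape[0], -1)
    lo = flat.min(axis=1, keepdims=True)
    hi = flat.max(axis=1, keepdims=True)
    return ((flat - lo) / (hi - lo + eps)).reshape(batch.shape)


def stitch_max(mask, patch_pred, r, c):
    """Write a patch prediction into the mask, keeping the larger class index on overlap."""
    ph, pw = patch_pred.shape
    mask[r:r+ph, c:c+pw] = np.maximum(mask[r:r+ph, c:c+pw], patch_pred)


batch = np.array([[[0.0, 5.0], [10.0, 20.0]],
                  [[-1.0, 0.0], [0.0, 1.0]]])
norm = minmax_per_patch(batch)
print(norm.min(), round(float(norm.max()), 3))  # 0.0 1.0

mask = np.zeros((4, 6), dtype=np.uint8)
stitch_max(mask, np.full((4, 4), 1, dtype=np.uint8), 0, 0)  # left patch predicts class 1
stitch_max(mask, np.full((4, 4), 2, dtype=np.uint8), 0, 2)  # right patch predicts class 2
print(mask[0])  # [1 1 2 2 2 2]
```

Note what the max-merge implies: where two patches disagree, the class with the larger index wins regardless of how confident either prediction was. With more than one anomaly class, the ordering of indices in `class_map` therefore affects the stitched mask.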

##### Step 5: Extract Anomaly Locations

Now that we have the full 2D prediction mask, we label the connected regions of each class with `scipy.ndimage.label` and record the center of mass of each region, in `(row, column)` pixel coordinates, as the anomaly's location.

In [None]:
detections = []
class_map_inv = {v: k for k, v in config.get('class_map', {}).items()}

for class_idx, class_name in class_map_inv.items():
    # Class 0 is skipped (assumed to be the background class)
    if class_idx == 0:
        continue
    class_mask = (prediction_mask == class_idx).astype(int)
    # Label connected regions, then take each region's center of mass as (row, col)
    labeled_array, num_features = label(class_mask)
    if num_features > 0:
        centers = center_of_mass(class_mask, labeled_array, range(1, num_features + 1))
        for center in centers:
            detections.append({'type': class_name, 'location_px': [int(round(c)) for c in center]})

print(f"Found {len(detections)} objects. Saving detections list...")
with open(DETECTIONS_PATH, 'w') as f:
    json.dump(detections, f, indent=4)
print(f"Detections list saved to: {DETECTIONS_PATH}")
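
The extraction step can be tried on a small hand-made mask to see what the detections look like. This sketch uses the same two `scipy.ndimage` calls as the cell above. The mask and the two class indices are invented for illustration.

```python
import numpy as np
from scipy.ndimage import label, center_of_mass

# Toy 8x8 mask: 0 = background, 1 and 2 are two hypothetical anomaly classes.
mask = np.zeros((8, 8), dtype=np.uint8)
mask[1:4, 1:4] = 1   # 3x3 region of class 1, centered at (2, 2)
mask[6, 6] = 1       # isolated class-1 pixel
mask[5, 0:3] = 2     # 1x3 region of class 2, centered at (5, 1)

found = []
for class_idx in (1, 2):
    class_mask = (mask == class_idx).astype(int)
    labeled, n = label(class_mask)  # default structure: 4-connectivity in 2D
    for center in center_of_mass(class_mask, labeled, range(1, n + 1)):
        found.append((class_idx, [int(round(float(v))) for v in center]))

print(found)  # [(1, [2, 2]), (1, [6, 6]), (2, [5, 1])]
```

Each location is `[row, column]`, not `[x, y]`, which matters when overlaying detections on a plot. With the default structuring element, pixels that touch only diagonally count as separate regions.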

##### Step 6: Visualize the Results

Finally, we use the `ResultVisualizer` to overlay our outputs on the processed data cube. Since all files were saved in the previous steps, this cell simply loads them for plotting.

In [None]:
print("Generating visualization...")
visualizer = ResultVisualizer(
    processed_data_path=PROCESSED_DATA_PATH,
    prediction_mask_path=MASK_PATH
)
visualizer.load_detection_list(results_path=DETECTIONS_PATH)
visualizer.plot_full_mask_overlay()