### Running Anomaly Detection with a Trained `ohana` Model

This notebook demonstrates how to load a pre-trained 3D U-Net model and use it to find anomalies in a sample H2RG exposure file. We will follow the core logic found in the `ohana.predict.predictor` module.

We will perform the following steps:
1.  **Set up paths** to the model, configuration, and data.
2.  **Load the `UNet3D` model** and its trained weights.
3.  **Load and preprocess** the exposure data.
4.  **Run patch-based inference** to generate a full prediction mask.
5.  **Extract anomaly locations** from the mask.
6.  **Visualize the results.**

##### Step 1: Imports and Configuration

First, let's import the necessary modules from the `ohana` package and other libraries. We also define the paths to our trained model, the configuration file, and the exposure data we want to analyze.

In [1]:
import torch
import yaml
import numpy as np
import os
from collections import OrderedDict
from tqdm import tqdm
from scipy.ndimage import label, center_of_mass
import matplotlib.pyplot as plt
import sys

In [None]:
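# Use every available core; fall back to 8 if the count cannot be determined
n_threads = os.cpu_count() or 8
print(f"Setting PyTorch to use {n_threads} threads.")
torch.set_num_threads(n_threads)
# Note: OpenMP/MKL usually read these variables when they start up, so setting them
# after `import torch` may have no effect; torch.set_num_threads() is what counts here.
os.environ["OMP_NUM_THREADS"] = str(n_threads)
os.environ["MKL_NUM_THREADS"] = str(n_threads)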

Setting PyTorch to use 16 threads.


In [3]:
sys.path.insert(0, '../')

In [4]:
# Import the necessary classes from your ohana package
# !NOTE: Make sure the 'ohana' directory is in your Python path

from ohana.models.unet_3d import UNet3D
from ohana.preprocessing.data_loader import DataLoader
from ohana.preprocessing.preprocessor import Preprocessor
from ohana.visualization.plotter import ResultVisualizer

In [5]:
""" Configuration """
# !NOTE: Replace these with the actual paths to your files.

# Path to the trained model
MODEL_PATH = "../trained_models/old_best_model_unet3d.pth"

# Path to the config file that was used for model training
CONFIG_PATH = "../configs/creator_config.yaml"

# Path to the exposure you want to run the predictions on
EXPOSURE_PATH = "/Volumes/jwst/ilongo/raw_data/18220_Euclid_SCA/ap30_100k_0p8m0p3_fullnoi_E001_18220.fits"

# Directory where model predictions will be stored
OUTPUT_DIR = "prediction_outputs"

In [6]:
# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


##### Step 2: Load the Trained Model

Next, we load the `UNet3D` model architecture and the saved weights from your `.pth` file. The code includes a step to handle models that were trained using `nn.DataParallel` on multiple GPUs. This logic comes directly from the `Predictor` class.

In [7]:
# Load the configuration file
with open(CONFIG_PATH, 'r') as f:
    config = yaml.safe_load(f)

# Initialize the model
# n_classes is read from 'num_classes' in the config so that it matches the trained weights
model = UNet3D(n_channels=1, n_classes=config['num_classes'])

# Load the trained weights
print(f"Loading model from: {MODEL_PATH}")
state_dict = torch.load(MODEL_PATH, map_location=device)

# Handle models saved with nn.DataParallel
if next(iter(state_dict)).startswith('module.'):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
else:
    model.load_state_dict(state_dict)

model.to(device)
model.eval() # Set the model to evaluation mode
print("Model loaded successfully.")

Loading model from: ../trained_models/old_best_model_unet3d.pth
Model loaded successfully.
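
The prefix handling above can also be written as a small reusable helper. The sketch below is illustrative only: `strip_module_prefix` is a hypothetical name, not part of `ohana`, and it works on any mapping of key names, so it can be tried without a checkpoint. Unlike the cell above, which inspects only the first key, it checks each key individually.

```python
from collections import OrderedDict

def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from any key that has it.

    Hypothetical helper: nn.DataParallel wraps the model in a `.module`
    attribute, so every saved parameter name gains a leading "module.".
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

# Stand-in values; in a real checkpoint these would be tensors
wrapped = OrderedDict([("module.conv.weight", 1), ("module.conv.bias", 2)])
plain = OrderedDict([("conv.weight", 1), ("conv.bias", 2)])

cleaned = strip_module_prefix(wrapped)
unchanged = strip_module_prefix(plain)
```

With a helper like this, the `if`/`else` branch collapses to a single `model.load_state_dict(strip_module_prefix(state_dict))` call, since keys without the prefix pass through untouched.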


##### Step 3: Load and Preprocess the Exposure Data

We'll use the `DataLoader` to load the raw exposure cube and the `Preprocessor` to clean it and create the difference-image cube, which is the actual input to our model.

In [None]:
# Initialize data loader and preprocessor
data_loader = DataLoader()
preprocessor = Preprocessor()

# Load and process the exposure
print(f"Loading and preprocessing exposure: {EXPOSURE_PATH}")
raw_exposure = data_loader.load_exposure(EXPOSURE_PATH)
processed_cube = preprocessor.process_exposure(raw_exposure) # Shape: (T, H, W)

print(f"Processed data cube shape: {processed_cube.shape}")

ReferencePixelCorrector initialized with x_opt=64, y_opt=4.
Preprocessor initialized. Reference pixel correction: Enabled.
Loading and preprocessing exposure: /Volumes/jwst/ilongo/raw_data/18220_Euclid_SCA/ap30_100k_0p8m0p3_fullnoi_E001_18220.fits
Loading data from Multi-Extension FITS file: /Volumes/jwst/ilongo/raw_data/18220_Euclid_SCA/ap30_100k_0p8m0p3_fullnoi_E001_18220.fits


Loading FITS extensions:  65%|██████▍   | 292/450 [05:15<02:43,  1.03s/it]
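
The internals of `Preprocessor.process_exposure` are not shown in this notebook. As a rough sketch, and assuming that "difference-image cube" means the difference between consecutive frames of the up-the-ramp exposure, the core operation might look like the following. The function name `difference_cube` and the synthetic ramp are illustrative assumptions; the real preprocessor also applies reference pixel correction, which is omitted here.

```python
import numpy as np

def difference_cube(raw_cube):
    """Frame-to-frame differences of a (T, H, W) cube, giving (T - 1, H, W).

    Assumption: this only illustrates the idea of a difference-image cube;
    it is not the ohana Preprocessor implementation.
    """
    # Cast to float first: raw detector counts are often unsigned integers,
    # and subtracting unsigned values can wrap around instead of going negative.
    raw = np.asarray(raw_cube, dtype=np.float32)
    return np.diff(raw, axis=0)

# Synthetic 5-frame ramp in which every pixel gains 10 counts per frame
frames = np.arange(5, dtype=np.float32)[:, None, None] * 10.0
ramp = frames + np.zeros((5, 4, 4), dtype=np.float32)

diff = difference_cube(ramp)
```

For a perfectly linear ramp every difference frame is constant, so a sudden jump or a slow drift in a pixel stands out against that flat background.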

##### Step 4: Run Patch-Based Inference

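The model operates on small, overlapping spatial patches of the data cube, each spanning the full time axis. We extract these patches, run the model on them in batches, take the class prediction from the central temporal slice of each output, and stitch the results into a single full-sized prediction mask. Where patches overlap, the higher class label wins. This process follows the logic of the `predict` method in the `Predictor` class.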
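
To make the tiling arithmetic concrete, here is a minimal sketch of how patch origins along one axis follow from the image size, patch size, and overlap. The numbers used (a 2048-pixel axis, 256-pixel patches, overlaps of 32 and 64) are assumed values for illustration, not values read from the config, and both function names are hypothetical.

```python
def patch_origins(length, patch, overlap):
    """Start offsets along one axis, as in range(0, length - patch + 1, patch - overlap)."""
    step = patch - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than the patch size")
    return list(range(0, length - patch + 1, step))

def uncovered_edge(length, patch, overlap):
    """Number of trailing pixels along the axis that no patch reaches."""
    origins = patch_origins(length, patch, overlap)
    if not origins:
        return length
    return length - (origins[-1] + patch)

# Assumed example values (not taken from the config file)
origins_32 = patch_origins(2048, 256, 32)   # step of 224 pixels
origins_64 = patch_origins(2048, 256, 64)   # step of 192 pixels
gap_32 = uncovered_edge(2048, 256, 32)
gap_64 = uncovered_edge(2048, 256, 64)
```

With an overlap of 32 the nine patches per axis tile the 2048 pixels exactly, while an overlap of 64 leaves the last 64 rows and columns without a prediction. If the configured sizes do not divide evenly, an extra patch anchored at `length - patch` would be one way to cover the edge.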

In [None]:
BATCH_SIZE = 64
# Force CPU inference and move the model as well, so model and inputs share a device
device = torch.device("cpu")
model.to(device)

# --- Patch Extraction Logic ---
print("Extracting all patches...")
patch_size = tuple(config['patch_size'])
overlap = config['overlap']
step_h = patch_size[0] - overlap
step_w = patch_size[1] - overlap
_, H, W = processed_cube.shape
prediction_mask = np.zeros((H, W), dtype=np.uint8)

patches_with_coords = []
for r in range(0, H - patch_size[0] + 1, step_h):
    for c in range(0, W - patch_size[1] + 1, step_w):
        patch_data = processed_cube[:, r:r+patch_size[0], c:c+patch_size[1]]
        patches_with_coords.append((patch_data, (r, c)))
print(f"Extracted {len(patches_with_coords)} patches.")


# --- Run BATCHED Inference on CPU ---
print(f"Running CPU inference with a batch size of {BATCH_SIZE}...")
for i in tqdm(range(0, len(patches_with_coords), BATCH_SIZE)):
    # Get a "batch" of patches from the list
    batch_patches_with_coords = patches_with_coords[i:i+BATCH_SIZE]
    batch_patch_data = [item[0] for item in batch_patches_with_coords]
    
    # Stack patches into a single batch tensor
    input_tensor = torch.from_numpy(np.stack(batch_patch_data)).float().unsqueeze(1).to(device)

    # --- Min-max normalize each patch independently (vectorized over the batch) ---
    b, c_in, t, h, w = input_tensor.shape
    tensor_flat = input_tensor.reshape(b, -1)
    min_val = tensor_flat.min(dim=1, keepdim=True)[0]
    max_val = tensor_flat.max(dim=1, keepdim=True)[0]
    normalized_tensor = (tensor_flat - min_val) / (max_val - min_val + 1e-6)
    normalized_tensor = normalized_tensor.view(b, c_in, t, h, w)

    # --- Run the model on the entire batch ---
    with torch.no_grad():
        logits = model(normalized_tensor)
        central_logits = logits[:, :, logits.shape[2] // 2, :, :]
        pred_batch_mask = torch.argmax(central_logits, dim=1).cpu().numpy()
    
    # --- Place the predicted patches into the full mask ---
    for j, (r, c) in enumerate([item[1] for item in batch_patches_with_coords]):
        pred_patch = pred_batch_mask[j]
        ph, pw = pred_patch.shape
        current_mask_region = prediction_mask[r:r+ph, c:c+pw]
        prediction_mask[r:r+ph, c:c+pw] = np.maximum(current_mask_region, pred_patch)

# --- SAVE ALL OUTPUTS ---
processed_data_path = os.path.join(OUTPUT_DIR, 'processed_cube.npy')
mask_path = os.path.join(OUTPUT_DIR, 'prediction_mask.npy')

np.save(processed_data_path, processed_cube)
print(f"✅ Processed data cube saved to: {processed_data_path}")

np.save(mask_path, prediction_mask)
print(f"✅ Full prediction mask saved to: {mask_path}")

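
The two array operations at the heart of the loop above, per-patch min-max normalization and maximum-based stitching, can be tried in isolation with NumPy. This is a sketch on toy arrays with no model involved; the function names `minmax_per_patch` and `stitch_max` are hypothetical.

```python
import numpy as np

def minmax_per_patch(batch, eps=1e-6):
    """Scale each patch of a (B, ...) batch to [0, 1] using its own min and max."""
    flat = batch.reshape(batch.shape[0], -1).astype(np.float32)
    lo = flat.min(axis=1, keepdims=True)
    hi = flat.max(axis=1, keepdims=True)
    # eps guards against division by zero for a constant patch
    return ((flat - lo) / (hi - lo + eps)).reshape(batch.shape)

def stitch_max(mask, pred_patch, r, c):
    """Merge pred_patch into mask at (r, c), keeping the larger label per pixel."""
    ph, pw = pred_patch.shape
    region = mask[r:r + ph, c:c + pw]
    mask[r:r + ph, c:c + pw] = np.maximum(region, pred_patch)
    return mask

# Two toy patches with the same shape but a 100x difference in scale
base = np.arange(8, dtype=np.float32).reshape(2, 2, 2)
norm = minmax_per_patch(np.stack([base, 100.0 * base]))

# Two 2x2 predictions that overlap in one column of a 2x3 mask
mask = np.zeros((2, 3), dtype=np.uint8)
left = np.array([[0, 2], [0, 0]], dtype=np.uint8)
right = np.array([[1, 0], [0, 0]], dtype=np.uint8)
stitch_max(mask, left, 0, 0)
stitch_max(mask, right, 0, 1)
```

Because each patch is scaled by its own range, the two toy patches normalize to nearly identical values despite the difference in scale. In the overlap, the pixel labeled 2 by one patch and 1 by the other keeps label 2, which means the merge favors higher class indices regardless of the order in which patches are processed.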