In [None]:
import os
import glob
import cv2
import numpy as np
import pandas as pd
import torch
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# --- 1. PATHS ---
# Folder containing the image sub-folders to analyze
IMAGE_FOLDER_PATH = "Data/images/"
# Sub-folder to run inference on; also used as the prefix of the results CSV filename
IMAGE_SPEC = "Obsidian_old"
IMAGE_PATH = os.path.join(IMAGE_FOLDER_PATH, IMAGE_SPEC)

# Trained model weights (.pth state_dict) saved from the training notebook
MODEL_FOLDER_PATH = "Data/segmentation_models/"
MODEL_SPEC = "obsidian_f1a05b05.pth"
MODEL_PATH = os.path.join(MODEL_FOLDER_PATH, MODEL_SPEC)

# Directory where results will be saved
OUTPUT_PATH = "Data/inference_results"

# --- 2. MODEL PARAMETERS (these MUST match the model you are loading) ---
MODEL_ARC = "Unet"
ENCODER = "resnet34"
# Selects the input normalization via smp.encoders.get_preprocessing_fn; the model
# itself is built with encoder_weights=None and gets its weights from MODEL_PATH.
ENCODER_WEIGHTS = "imagenet"

# --- 3. PREPROCESSING PARAMETERS (these MUST match how the model was trained) ---
# Image size (pixels) the model was trained on; every image is resized to this before inference
IMG_HEIGHT = 1024
IMG_WIDTH = 1024
# Number of input channels the model is built with (1 for grayscale, 3 for color).
# NOTE(review): the inference loop always loads images as 3-channel RGB, so values
# other than 3 will not match the tensors fed to the model — confirm before changing.
INPUT_CHANNELS_CONFIG = 3

# --- 4. INFERENCE PARAMETERS ---
# Threshold applied to the sigmoid probabilities (prob > threshold -> foreground)
PREDICTION_THRESHOLD = 0.5
# Physical pixel size in micrometres per pixel. When set, track areas are also
# reported in µm² ("area_um2" column). Leave as None if unknown.
PIXEL_RESOLUTION_UM_PER_PX = None

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Validate inputs before touching the filesystem, so a mistyped path fails fast
# without leaving an empty output directory behind.
if not os.path.isdir(IMAGE_PATH):
    raise FileNotFoundError(f"ERROR: Image folder not found at '{IMAGE_PATH}'")
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"ERROR: Model file not found at '{MODEL_PATH}'")
os.makedirs(OUTPUT_PATH, exist_ok=True)

print("Configuration Loaded:")
print(f"  Image Folder: {IMAGE_PATH}")
print(f"  Model Path: {MODEL_PATH}")
print(f"  Output Directory: {OUTPUT_PATH}")
print(f"  Using Device: {DEVICE}")

In [None]:
def get_inference_augs(height, width):
    """Build the inference-time geometric transform: resize to the model's input size.

    Args:
        height: Target height in pixels (should match the training resolution).
        width: Target width in pixels (should match the training resolution).

    Returns:
        An ``A.Compose`` pipeline that bilinearly resizes ``image`` to (height, width).
    """
    # Resize is always applied by default (p=1). The former `always_apply=True`
    # flag is deprecated in albumentations 1.4 and removed in 2.x, so it is omitted.
    return A.Compose([
        A.Resize(height, width, interpolation=cv2.INTER_LINEAR),
    ])

def get_preprocessing(preprocessing_fn):
    """Build the normalization + tensor-conversion pipeline.

    Args:
        preprocessing_fn: Callable applied to the image array; in this notebook it is
            the encoder-specific function from ``smp.encoders.get_preprocessing_fn``.

    Returns:
        An ``A.Compose`` that first applies ``preprocessing_fn`` to ``image`` and then
        converts the result to a ``torch.Tensor`` with ``ToTensorV2``.
    """
    normalize_step = A.Lambda(image=preprocessing_fn)
    tensor_step = ToTensorV2()
    return A.Compose([normalize_step, tensor_step])

In [None]:
print("Loading trained model...")

try:
    # Rebuild the architecture used in training. encoder_weights=None because all
    # weights (encoder included) are loaded from MODEL_PATH below.
    inference_model = smp.create_model(
        arch=MODEL_ARC,
        encoder_name=ENCODER,
        encoder_weights=None,
        in_channels=INPUT_CHANNELS_CONFIG,
        classes=1,        # single output channel (binary segmentation)
        activation=None,  # raw logits; the inference loop applies torch.sigmoid
    )
    # weights_only=True restricts unpickling to tensors and plain containers, so a
    # tampered .pth cannot execute arbitrary code (it is also the default from
    # PyTorch 2.6). A plain state_dict loads fine under this restriction.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    inference_model.load_state_dict(state_dict)
    inference_model.to(DEVICE)
    inference_model.eval()  # inference behaviour for dropout / batch-norm layers
    print("Model loaded successfully and set to evaluation mode.")
except Exception as e:
    print(f"Error creating or loading model: {e}")
    print("Please ensure MODEL_ARC and ENCODER parameters in Cell 2 are correct.")
    raise  # bare raise re-raises the original exception unchanged

# Input normalization matching the encoder's pretraining (ENCODER_WEIGHTS)
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
print("Preprocessing function loaded.")

In [None]:
all_detected_properties = []   # one dict per detected track, across all images
visualization_samples = []     # first few (image, mask) pairs for the overlay cell
failed_images = []             # filenames that could not be read or processed
NUM_SAMPLES_TO_VISUALIZE = 3

# Build the preprocessing pipelines once, outside the loop
inference_augs_pipeline = get_inference_augs(IMG_HEIGHT, IMG_WIDTH)
preprocessing_pipeline = get_preprocessing(preprocessing_fn)

IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')
# Sorted so processing order, CSV row order and the visualized samples are
# reproducible; glob order depends on the filesystem.
image_paths = sorted(
    p for p in glob.glob(os.path.join(IMAGE_PATH, '*'))
    if p.lower().endswith(IMAGE_EXTENSIONS)
)


def _predict_mask(image_rgb):
    """Run the model on one RGB image and return a 0/1 uint8 mask at the image's original size.

    The image goes through the same resize + normalization pipelines as in training.
    The logits are converted with a sigmoid and thresholded at PREDICTION_THRESHOLD.
    The mask is resized back with nearest-neighbour interpolation so it stays binary.
    """
    original_h, original_w = image_rgb.shape[:2]
    # NOTE(review): images are always fed as 3-channel RGB; a model built with
    # INPUT_CHANNELS_CONFIG != 3 will reject them — confirm how such models were trained.
    augmented = inference_augs_pipeline(image=image_rgb.copy())
    preprocessed = preprocessing_pipeline(image=augmented['image'])
    input_tensor = preprocessed['image'].unsqueeze(0).to(DEVICE, dtype=torch.float32)

    with torch.no_grad():
        logits = inference_model(input_tensor)

    prob_map = torch.sigmoid(logits).squeeze().cpu().numpy()
    mask_binary = (prob_map > PREDICTION_THRESHOLD).astype(np.uint8)
    # cv2.resize takes the target size as (width, height)
    return cv2.resize(mask_binary, (original_w, original_h), interpolation=cv2.INTER_NEAREST)


def _measure_tracks(mask, filename):
    """Return one property dict per 8-connected foreground component of `mask`.

    track_id is the component label within this image, so it restarts at 1 for each image.
    """
    num_labels, _, stats, centroids = cv2.connectedComponentsWithStats(mask, 8, cv2.CV_32S)
    tracks = []
    for label in range(1, num_labels):  # label 0 is the background
        props = {
            "image_filename": filename, "track_id": label,
            "area_px": stats[label, cv2.CC_STAT_AREA],
            "centroid_x_px": round(centroids[label][0], 1),
            "centroid_y_px": round(centroids[label][1], 1),
        }
        if PIXEL_RESOLUTION_UM_PER_PX is not None:
            props["area_um2"] = round(props["area_px"] * (PIXEL_RESOLUTION_UM_PER_PX**2), 2)
        tracks.append(props)
    return tracks


if not image_paths:
    print(f"Error: No images found in '{IMAGE_PATH}'.")
else:
    print(f"\nFound {len(image_paths)} images to process. Starting batch inference...")
    for img_path in tqdm(image_paths, desc="Processing Images"):
        filename = os.path.basename(img_path)
        try:
            original_image_bgr = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if original_image_bgr is None:
                # cv2.imread returns None (instead of raising) for unreadable/corrupt files
                print(f"Error processing image {filename}: file could not be read by OpenCV")
                failed_images.append(filename)
                continue
            original_image_rgb = cv2.cvtColor(original_image_bgr, cv2.COLOR_BGR2RGB)

            pred_mask_resized = _predict_mask(original_image_rgb)
            all_detected_properties.extend(_measure_tracks(pred_mask_resized, filename))

            # Keep a few samples for the visualization cell
            if len(visualization_samples) < NUM_SAMPLES_TO_VISUALIZE:
                visualization_samples.append({
                    "original": original_image_rgb,
                    "predicted_mask": pred_mask_resized,
                    "filename": filename
                })

        except Exception as e_loop:
            # Best-effort batch: report and move on so one bad image does not abort the run
            print(f"Error processing image {filename}: {e_loop}")
            failed_images.append(filename)

    print("Batch inference complete.")
    if failed_images:
        print(f"{len(failed_images)} of {len(image_paths)} images could not be processed: {failed_images}")

In [None]:
if not all_detected_properties:
    print("\nNo tracks were detected in any of the processed images.")
else:
    # --- Create DataFrame (one row per detected track) ---
    properties_df = pd.DataFrame(all_detected_properties)
    print("\n--- Aggregated Results ---")
    print(f"Total tracks found across all images: {len(properties_df)}")

    # --- Save to CSV first, so results persist even if display/plotting fails ---
    csv_output_path = os.path.join(OUTPUT_PATH, IMAGE_SPEC + "_detected_track_properties.csv")
    properties_df.to_csv(csv_output_path, index=False)
    print(f"\nFull results saved to: {csv_output_path}")

    print("Sample of detected track properties:")
    display(properties_df.head())

    # --- Area distribution ---
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(properties_df['area_px'], bins=50, edgecolor='black')  # adjust bins as needed
    ax.set_title('Distribution of Track Areas (in Pixels)')
    ax.set_xlabel('Area (pixels)')
    ax.set_ylabel('Frequency')
    ax.grid(axis='y', alpha=0.75)
    plt.show()

In [None]:
if visualization_samples:
    print("\n--- Example Visualizations with Overlays ---")
    for sample in visualization_samples:
        # Outline every predicted track (outer contours only) in green, 1 px wide,
        # on a copy so the stored original image is left untouched.
        track_outlines, _ = cv2.findContours(
            sample["predicted_mask"], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        overlay_img = cv2.drawContours(sample["original"].copy(), track_outlines, -1, (0, 255, 0), 1)

        fig, ax = plt.subplots(figsize=(7, 7))
        ax.imshow(overlay_img)
        ax.set_title(f"Predicted Tracks on: {sample['filename']}")
        ax.axis('off')
        plt.show()
else:
    print("\nNo samples were saved for visualization.")