# 3D CNN Nucleus State Inference Example
This notebook demonstrates how to run inference using the `predict.py` script from the ncnn4 model. The script supports batch prediction, single sample prediction, and full volume prediction with optional nuclei selection.


In [None]:
import numpy as np
from predict import load_model, process_single_sample_by_np_volumes

# Load your numpy arrays (replace with your own loading logic)
vol_dict = {
    't-1': np.load('t_minus_1.npy'),
    't': np.load('t.npy'),
    't+1': np.load('t_plus_1.npy'),
    'mask': np.load('mask.npy'),
}
model = load_model('training_outputs/best_model.pth')
model.eval()
import argparse
args = argparse.Namespace(output_dir='./analysis_output', save_analysis=False, verbose=True)
result = process_single_sample_by_np_volumes(vol_dict, model, args)
print(result)

In [None]:
import numpy as np
from predict import load_model, handle_batch_folder_prediction

# Load the model
model = load_model('training_outputs/best_model.pth')
model.eval()
# List all the directory paths for batch prediction (this is the output of the preperation setps)

batch_dirs = [
    '/data/nuclei_state_dataset/v3/new_daughter/230212_stack6_frame_025_nucleus_011_count_15',
    '/mnt/home/dchhantyal/3d-cnn-classification/data/nuclei_state_dataset/v3/mitotic/230212_stack6_frame_027_nucleus_014_count_16',
]

import argparse

args = argparse.Namespace(
    folder_path=batch_dirs,
    save_analysis=False,
    verbose=True
)
results = handle_batch_folder_prediction(model, args)
print("Batch prediction results:", results)

In [None]:
# Example: Programmatic inference using predict.py functions
from predict import (
    load_model,
    handle_full_timestamp_prediction,
)
import argparse

# Set up model path and volume paths
model_path = '/mnt/home/dchhantyal/3d-cnn-classification/model/ncnn4/training_outputs/no-aug/best_model.pth'
volume_paths = [
    '/mnt/home/dchhantyal/3d-cnn-classification/raw-data/230212_stack6/registered_images/nuclei_reg8_29.tif',
    '/mnt/home/dchhantyal/3d-cnn-classification/raw-data/230212_stack6/registered_images/nuclei_reg8_30.tif',
    '/mnt/home/dchhantyal/3d-cnn-classification/raw-data/230212_stack6/registered_images/nuclei_reg8_31.tif',
    '/mnt/home/dchhantyal/3d-cnn-classification/raw-data/230212_stack6/registered_label_images/label_reg8_30.tif',
]

# Load the trained model
model = load_model(model_path, verbose=False)
model.eval()


# Example: Predict specific nuclei by IDs
nuclei_ids = [16, 17]  

# To run full prediction and get results, you would need to construct an argparse.Namespace or refactor the script for direct function calls.

results = handle_full_timestamp_prediction(
    model=model,
    args=argparse.Namespace(
        model_path=model_path,
        volumes=volume_paths,
        full_timestamp=True,
        nuclei_ids=None,
    ),
)
print("Prediction results:", results)

Prediction results: [{'sample': 'sample_20250731_110036', 'true_class': None, 'predicted_class': 'stable', 'index': 2, 'confidence': 1.0, 'correct': None, 'processing_time': 2.589352}, {'sample': 'sample_20250731_110039', 'true_class': None, 'predicted_class': 'stable', 'index': 2, 'confidence': 0.9992713332176208, 'correct': None, 'processing_time': 3.199789}, {'sample': 'sample_20250731_110042', 'true_class': None, 'predicted_class': 'stable', 'index': 2, 'confidence': 0.8772227168083191, 'correct': None, 'processing_time': 5.10013}, {'sample': 'sample_20250731_110047', 'true_class': None, 'predicted_class': 'stable', 'index': 2, 'confidence': 0.9581482410430908, 'correct': None, 'processing_time': 4.30028}, {'sample': 'sample_20250731_110052', 'true_class': None, 'predicted_class': 'stable', 'index': 2, 'confidence': 0.8139852285385132, 'correct': None, 'processing_time': 1.801851}, {'sample': 'sample_20250731_110053', 'true_class': None, 'predicted_class': 'stable', 'index': 2, 'co

In [None]:
import numpy as np
from predict import load_model, handle_single_volume_prediction

# Load the model
model = load_model('training_outputs/best_model.pth')
model.eval()
# List volume of cropped t-1, t, t+1, and mask, note: ONLY .tif is accepted
volumes = [
    '/mnt/home/dchhantyal/3d-cnn-classification/data/nuclei_state_dataset/v3/mitotic/230212_stack6_frame_027_nucleus_014_count_16/t-1/raw_cropped.tif',
    '/mnt/home/dchhantyal/3d-cnn-classification/data/nuclei_state_dataset/v3/mitotic/230212_stack6_frame_027_nucleus_014_count_16/t/raw_cropped.tif',
    '/mnt/home/dchhantyal/3d-cnn-classification/data/nuclei_state_dataset/v3/mitotic/230212_stack6_frame_027_nucleus_014_count_16/t+1/raw_cropped.tif',
    '/mnt/home/dchhantyal/3d-cnn-classification/data/nuclei_state_dataset/v3/mitotic/230212_stack6_frame_027_nucleus_014_count_16/t/binary_label_cropped.tif'
]

import argparse

args = argparse.Namespace(
    volumes=volumes,
    full_timestamp=False,
    nuclei_ids=None,  # Specify nuclei IDs if needed
    save_analysis=False,
    verbose=True
)

results = handle_single_volume_prediction(model, args)
print("Single volume prediction results:", results)


In [1]:

# Example: Predict specific nuclei by IDs (e.g., 16, 17) in the same volume
!srun python predict.py --model_path /mnt/home/dchhantyal/3d-cnn-classification/model/ncnn4/training_outputs/no-aug/best_model.pth --volumes /mnt/home/dchhantyal/3d-cnn-classification/raw-data/230212_stack6/registered_images/nuclei_reg8_29.tif /mnt/home/dchhantyal/3d-cnn-classification/raw-data/230212_stack6/registered_images/nuclei_reg8_30.tif /mnt/home/dchhantyal/3d-cnn-classification/raw-data/230212_stack6/registered_images/nuclei_reg8_31.tif /mnt/home/dchhantyal/3d-cnn-classification/raw-data/230212_stack6/registered_label_images/label_reg8_30.tif --full_timestamp --nuclei_ids "16,17"


🎉 BATCH PROCESSING COMPLETE
 1. sample_20250730_111342              | True: UNKNOWN      → Pred: STABLE       (83.9%) ❓
 2. sample_20250730_111342              | True: UNKNOWN      → Pred: STABLE       (83.3%) ❓

📊 ACCURACY SUMMARY
Total samples processed successfully: 2


In [None]:
!srun python predict.py --model_path /mnt/home/dchhantyal/3d-cnn-classification/model/ncnn4/training_outputs/no-aug/best_model.pth --volumes /mnt/home/dchhantyal/3d-cnn-classification/raw-data/230212_stack6/registered_images/nuclei_reg8_29.tif /mnt/home/dchhantyal/3d-cnn-classification/raw-data/230212_stack6/registered_images/nuclei_reg8_30.tif /mnt/home/dchhantyal/3d-cnn-classification/raw-data/230212_stack6/registered_images/nuclei_reg8_31.tif /mnt/home/dchhantyal/3d-cnn-classification/raw-data/230212_stack6/registered_label_images/label_reg8_30.tif --full_timestamp 


🎉 BATCH PROCESSING COMPLETE
 1. sample_20250730_111700              | True: UNKNOWN      → Pred: STABLE       (100.0%) ❓
 2. sample_20250730_111700              | True: UNKNOWN      → Pred: STABLE       (99.9%) ❓
 3. sample_20250730_111700              | True: UNKNOWN      → Pred: STABLE       (87.7%) ❓
 4. sample_20250730_111700              | True: UNKNOWN      → Pred: STABLE       (95.8%) ❓
 5. sample_20250730_111700              | True: UNKNOWN      → Pred: STABLE       (81.4%) ❓
 6. sample_20250730_111700              | True: UNKNOWN      → Pred: STABLE       (93.7%) ❓
 7. sample_20250730_111700              | True: UNKNOWN      → Pred: STABLE       (89.9%) ❓
 8. sample_20250730_111700              | True: UNKNOWN      → Pred: STABLE       (72.5%) ❓
 9. sample_20250730_111700              | True: UNKNOWN      → Pred: STABLE       (92.0%) ❓
10. sample_20250730_111700              | True: UNKNOWN      → Pred: STABLE       (80.3%) ❓
11. sample_20250730_111700              | True: UN