In [None]:
# Import required libraries
import sys
import os
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

# Add the cell_generator module to path
# NOTE(review): absolute, machine-specific checkout location; inserted at position 0
# so these project modules take precedence over same-named installed packages.
sys.path.insert(0, '/home/lionb/cell_generator')

# Import cell_generator modules (resolved through the path inserted above)
from dataset import DataGen
import global_vars as gv  # shared configuration module; its attributes are overridden in a later cell
from mg_analyzer import analyze_th
# NOTE(review): wildcard import hides which helpers this notebook actually relies on
# and may shadow names defined above; prefer importing the needed names explicitly.
from utils import *

print("Imports successful!")

In [None]:
# Locations of the Nuclear-envelope test split and of the trained model.
# NOTE(review): absolute cluster paths; adjust them when running elsewhere.
data_csv_path = "/mnt/new_groups/assafza_group/assafza/lion_models_clean/train_test_list/Nuclear-envelope/test.csv"
model_path = "/mnt/new_groups/assafza_group/assafza/lion_models_clean/models/mg_model_ne_13_05_24_1.0"

# Report whether both locations are reachable before anything reads them.
csv_found = os.path.exists(data_csv_path)
model_found = os.path.exists(model_path)
print(f"Data CSV exists: {csv_found}")
print(f"Model path exists: {model_found}")

# Peek at the test table: dimensions, column names and the leading rows.
data_df = pd.read_csv(data_csv_path)
preview_rows = data_df.head()
print(f"\nDataset shape: {data_df.shape}")
print(f"Dataset columns: {data_df.columns.tolist()}")
print(f"First few rows:\n{preview_rows}")

In [None]:
# Point the shared configuration module at the Nuclear-envelope model and data.
# The test CSV is reused for the train path too (this notebook does no training).
# Overrides are applied in insertion order.
gv_overrides = {
    "organelle": "Nuclear-envelope",
    "model_path": model_path,
    "train_ds_path": data_csv_path,
    "test_ds_path": data_csv_path,
    # input / target channel columns
    "input": "channel_signal",
    "target": "channel_target",
    # model parameters
    "model_type": "MG",
    "batch_size": 4,
    "patch_size": (32, 128, 128, 1),
    "number_epochs": 100,
}
for setting, setting_value in gv_overrides.items():
    setattr(gv, setting, setting_value)

# Echo the values downstream cells will read.
print("Updated global variables:")
print(f"  organelle: {gv.organelle}")
print(f"  model_path: {gv.model_path}")
print(f"  model_type: {gv.model_type}")
print(f"  batch_size: {gv.batch_size}")
print(f"  patch_size: {gv.patch_size}")
print(f"  input: {gv.input}")
print(f"  target: {gv.target}")

In [None]:
# Build a DataGen over the test CSV (augmentation off, "std" normalization).
print(f"Creating dataset from: {data_csv_path}")
print(f"Using input column: '{gv.input}' and target column: '{gv.target}'")

# Keep the keyword spellings (`min_precentage`, `delete_cahce`) as-is: they must
# match DataGen's parameter names.
datagen_kwargs = dict(
    input_col=gv.input,
    target_col=gv.target,
    batch_size=gv.batch_size,
    num_batches=4,
    patch_size=gv.patch_size,
    min_precentage=0.0,
    max_precentage=1.0,
    augment=False,
    norm_type="std",
    delete_cahce=True,
)
test_dataset = DataGen(data_csv_path, **datagen_kwargs)

print("Dataset created successfully!")
print(f"Number of images in dataset: {len(test_dataset.list_of_image_keys)}")

In [None]:
# Load the pre-trained MaskInterpreter model. On any failure the error is reported
# and `model` is set to None; the visualization cell checks for that.
print(f"Loading model from: {model_path}")

try:
    model = keras.models.load_model(model_path)
    print("Model loaded successfully!")
    print("Model summary:")
    model.summary()
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    model = None

In [None]:
# analyze_th: predicts on the original images, builds importance masks, thresholds
# them, computes mask efficacy (PCC between predictions) and saves visualizations
# and result tables.
print("Running analyze_th for predictions and importance mask generation...")
print("This may take a while depending on the dataset size...")

# Demo subset: the first 2 images, capped at the dataset size.
# Use range(len(test_dataset.list_of_image_keys)) to analyze every image.
num_images_to_analyze = 2
available_images = len(test_dataset.list_of_image_keys)
images_to_analyze = range(min(num_images_to_analyze, available_images))

try:
    analyze_th(
        dataset=test_dataset,
        model=model,
        model_path=model_path,
        images=images_to_analyze,
        mode="regular",        # one of "agg", "loo", "mask", "regular"
        manual_th="full",      # use the full importance mask, no thresholding
        noise_scale=1.5,       # noise scale for perturbation
        weighted_pcc=False,    # plain (unweighted) PCC
        compound=None,         # no compound/drug perturbation
        save_image=True,       # write visualizations
        save_histo=False,      # do NOT write importance-mask histograms
        save_results=True,
        results_save_path=None,  # None -> results go under model_path/predictions/
    )
    print("analyze_th completed successfully!")
except Exception as analysis_error:
    print(f"Error during analyze_th: {analysis_error}")
    import traceback
    traceback.print_exc()

In [None]:
# Create visualization function to display input, prediction, and explanation masks
def _plot_prediction_panels(input_slice, target_slice, pred_slice, mask_slice, title):
    """Draw the 2x3 input/target/prediction/importance-mask figure for one 2-D slice and return it."""
    fig = plt.figure(figsize=(16, 12))
    gs = GridSpec(2, 3, figure=fig, hspace=0.3, wspace=0.3)

    # Row 1: Input, Target, Prediction
    for spec, data, cmap, panel_title in (
        (gs[0, 0], input_slice, 'viridis', 'Input Signal'),
        (gs[0, 1], target_slice, 'hot', 'Target'),
        (gs[0, 2], pred_slice, 'hot', 'Prediction'),
    ):
        ax = fig.add_subplot(spec)
        ax.imshow(data, cmap=cmap)
        ax.set_title(panel_title)
        ax.axis('off')

    # Row 2: importance mask on its own, then overlaid on input and on target
    ax_mask = fig.add_subplot(gs[1, 0])
    im_mask = ax_mask.imshow(mask_slice, cmap='cool')
    ax_mask.set_title('Importance Mask')
    ax_mask.axis('off')
    fig.colorbar(im_mask, ax=ax_mask)

    for spec, background, panel_title in (
        (gs[1, 1], input_slice, 'Input + Mask Overlay'),
        (gs[1, 2], target_slice, 'Target + Mask Overlay'),
    ):
        ax = fig.add_subplot(spec)
        ax.imshow(background, cmap='gray', alpha=0.7)
        ax.imshow(mask_slice, cmap='Reds', alpha=0.4)
        ax.set_title(panel_title)
        ax.axis('off')

    fig.suptitle(title, fontsize=16, fontweight='bold')
    return fig


def visualize_results(dataset, image_index, model, patch_size, save_path=None,
                      center_xy=(312, 462), margin=(192, 256), xy_step=64, z_step=16,
                      batch_size=4):
    """
    Visualize input, target, prediction and importance (explanation) mask of one image.

    The model is run patch-wise over a rectangular x/y region spanning all z-slices.
    Input and target are cropped to that same region, so all six panels (and the
    overlays) are spatially aligned. The middle z-slice of the region is shown.

    Args:
        dataset: DataGen object providing ``get_image_row``, ``image_path_col``,
            ``input_col`` and ``target_col``.
        image_index: Index of image to visualize.
        model: Loaded MaskInterpreter model, passed to ``mg_analyzer.predict``.
        patch_size: Patch size used by model.
        save_path: Optional path to save figure.
        center_xy: (x, y) centre of the region to predict on.
        margin: (x, y) half-extent of the region around ``center_xy``; one extra
            ``xy_step`` is added on every side.
        xy_step: Stride between patches in x/y.
        z_step: Stride between patches in z.
        batch_size: Batch size passed to ``predict``.

    Returns:
        (pred_image, mask_image) as returned by ``assemble_image`` for the region,
        requested with shape (z, x-extent, y-extent, 1).

    Raises:
        ValueError: if the requested region does not lie inside the image.
    """
    from cell_imaging_utils.image.image_utils import ImageUtils
    from mg_analyzer import collect_patchs, predict, assemble_image

    # Load image
    print(f"Loading image {image_index}...")
    image_row = dataset.get_image_row(image_index)
    path_tiff = image_row[dataset.image_path_col]

    # Load input and target channels
    input_image = ImageUtils.imread_tiff(path_tiff, image_row[dataset.input_col])
    target_image = ImageUtils.imread_tiff(path_tiff, image_row[dataset.target_col])

    # Per-image z-score normalization; the epsilon avoids division by zero on a flat image.
    input_image = (input_image - np.mean(input_image)) / (np.std(input_image) + 1e-6)
    target_image = (target_image - np.mean(target_image)) / (np.std(target_image) + 1e-6)

    print(f"Input shape: {input_image.shape}, Target shape: {target_image.shape}")
    print("Generating predictions and importance masks...")

    # Region to predict on: all z-slices, x/y window around center_xy padded by one step.
    px_start = center_xy[0] - margin[0] - xy_step
    py_start = center_xy[1] - margin[1] - xy_step
    pz_start = 0
    px_end = center_xy[0] + margin[0] + xy_step
    py_end = center_xy[1] + margin[1] + xy_step
    pz_end = input_image.shape[0]

    # NumPy slicing would silently wrap negative starts or truncate oversized ends,
    # producing misaligned crops, so reject such regions up front.
    if px_start < 0 or py_start < 0 or px_end > input_image.shape[1] or py_end > input_image.shape[2]:
        raise ValueError(
            f"Prediction region x[{px_start}:{px_end}], y[{py_start}:{py_end}] does not fit "
            f"inside image of shape {input_image.shape}; adjust center_xy/margin."
        )

    region_shape = (pz_end - pz_start, px_end - px_start, py_end - py_start, 1)

    input_patchs = collect_patchs(px_start, py_start, pz_start, px_end, py_end, pz_end,
                                  input_image, patch_size, xy_step, z_step)
    predictions, importance_masks = predict(model, input_patchs, batch_size=batch_size)

    # Reassemble patch outputs into the region volume.
    pred_image = assemble_image(predictions, region_shape, patch_size, xy_step, z_step)
    mask_image = assemble_image(importance_masks, region_shape, patch_size, xy_step, z_step)

    # Predictions only cover the region, so crop input/target identically for display.
    # Axis order (z, x, y, channel) follows region_shape as passed to assemble_image.
    region = (slice(pz_start, pz_end), slice(px_start, px_end), slice(py_start, py_end))
    input_view = input_image[region]
    target_view = target_image[region]

    # Middle z-slice of the region (region starts at z=0, so this is also the image index).
    z_mid = (pz_end - pz_start) // 2

    fig = _plot_prediction_panels(
        input_view[z_mid, :, :, 0],
        target_view[z_mid, :, :, 0],
        pred_image[z_mid, :, :, 0],
        mask_image[z_mid, :, :, 0],
        title=f'Image {image_index} - Z-slice {z_mid}',
    )

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    plt.show()

    return pred_image, mask_image

print("Visualization function created successfully!")

In [None]:
# Plot image 0 of the test set; skipped if the model failed to load or the set is empty.
if model is None or len(test_dataset.list_of_image_keys) == 0:
    print("Model not loaded or no images available!")
else:
    try:
        pred_image, mask_image = visualize_results(
            dataset=test_dataset,
            model=model,
            image_index=0,
            patch_size=gv.patch_size,
            save_path=None,  # e.g. "/tmp/prediction_viz.png" to also write the figure
        )
        print("Visualization completed!")
    except Exception as viz_error:
        print(f"Error during visualization: {viz_error}")
        import traceback
        traceback.print_exc()

In [None]:
# Show what analyze_th wrote under <model_path>/predictions.
results_dir = f"{model_path}/predictions"

if not os.path.exists(results_dir):
    print(f"Results directory not found: {results_dir}")
else:
    print(f"Results directory: {results_dir}")

    # File names kept verbatim (including "resuls"); presumably these are the names
    # analyze_th writes. Verify before changing the spelling.
    pcc_results_path = f"{results_dir}/pcc_resuls.csv"
    mask_size_results_path = f"{results_dir}/mask_size_resuls.csv"

    if os.path.exists(pcc_results_path):
        pcc_results = pd.read_csv(pcc_results_path)
        print("\nPCC Results (Mask Efficacy):")
        print(pcc_results)

    if os.path.exists(mask_size_results_path):
        mask_size_results = pd.read_csv(mask_size_results_path)
        print("\nMask Size Results:")
        print(mask_size_results)

    # Tag every entry of the results directory as a directory or a file.
    print(f"\nGenerated files in {results_dir}:")
    for entry_name in os.listdir(results_dir):
        if os.path.isdir(os.path.join(results_dir, entry_name)):
            print(f"  [DIR] {entry_name}")
        else:
            print(f"  [FILE] {entry_name}")