In [1]:
# Enable IPython's autoreload extension so that edited modules are re-imported
# automatically before each cell executes (mode 2 = reload all imported modules).
%load_ext autoreload
%autoreload 2

In [3]:
# Import required libraries
import sys
import os
import pandas as pd
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import init_env_vars  # NOTE(review): presumably populates REPO_LOCAL_PATH (and related variables) on import — confirm

# Add the cell_generator module to path
# REPO_LOCAL_PATH must already be present in the environment when this line runs;
# otherwise os.environ[...] raises KeyError. The notebook cell that assigns
# REPO_LOCAL_PATH explicitly comes AFTER this one, so the value it sets is not
# the one used to build this search path.
sys.path.insert(0, os.path.join(os.environ['REPO_LOCAL_PATH'], 'cell_generator'))

# Import cell_generator modules
# (resolved through the sys.path entry inserted above)
from dataset import DataGen
import global_vars as gv
from mg_analyzer import analyze_th
# NOTE(review): wildcard import hides which names come from utils; switch to
# explicit names once the ones actually used by the notebook are known.
from utils import *

print("Imports successful!")

Imports successful!


In [None]:
# Point the notebook session at the local repository, the example data lists,
# and the directory of trained models.
os.environ.update({
    'REPO_LOCAL_PATH': "/home/lionb",
    'DATA_PATH': "/mnt/new_groups/assafza_group/assafza/lion_models_clean/example_data/train_test_list/",
    'MODELS_PATH': "/mnt/new_groups/assafza_group/assafza/lion_models_clean/models/",
})

# Echo the values back so the active configuration is visible in the cell output.
print("Environment variables set:")
for env_name in ('REPO_LOCAL_PATH', 'DATA_PATH', 'MODELS_PATH'):
    print(f"  {env_name}: {os.environ[env_name]}")

Environment variables set:
  REPO_LOCAL_PATH: /home/lionb


KeyError: 'DATA_PATH'

In [None]:
# Currently disabled: run replace_dir.sh to update the file paths inside the CSV files.
# Uncomment the line below to execute it.
# !cd ${REPO_LOCAL_PATH}/cell_generator && bash replace_dir.sh

In [None]:
# Set up file paths
# Index os.environ directly: if a variable is missing this fails immediately with a
# KeyError naming it, instead of os.environ.get() returning None and os.path.join()
# raising an opaque TypeError.
data_csv_path = os.path.join(os.environ['DATA_PATH'], 'Nuclear-envelope/image_list_test.csv')
model_path = os.path.join(os.environ['MODELS_PATH'], 'mg_model_ne_13_05_24_1.0')

# Verify paths exist (informational only; nothing is raised here)
print(f"Data CSV exists: {os.path.exists(data_csv_path)}")
print(f"Model path exists: {os.path.exists(model_path)}")

# Load the CSV to check data; pd.read_csv raises if the file is missing
data_df = pd.read_csv(data_csv_path)
print(f"\nDataset shape: {data_df.shape}")
print(f"Dataset columns: {data_df.columns.tolist()}")
print(f"First few rows:\n{data_df.head()}")

In [None]:
# Override the shared global_vars settings for the Nuclear-envelope organelle.
# Each entry is assigned as an attribute on the `gv` module, in the order listed.
gv_overrides = {
    # Organelle, model location and test list
    "organelle": "Nuclear-envelope",
    "model_path": model_path,
    "test_ds_path": data_csv_path,
    # Input and target channels
    "input": "channel_signal",
    "target": "channel_target",
    # Model parameters
    "model_type": "MG",
    "batch_size": 4,
    "patch_size": (32, 128, 128, 1),
}
for attr_name, attr_value in gv_overrides.items():
    setattr(gv, attr_name, attr_value)

# Display the updated global variables (read back from gv, not from the table above)
print("Updated global variables:")
for attr_name in ("organelle", "model_path", "model_type", "batch_size", "patch_size", "input", "target"):
    print(f"  {attr_name}: {getattr(gv, attr_name)}")

In [None]:
# Build the test-set generator from the CSV file
print(f"Creating dataset from: {data_csv_path}")
print(f"Using input column: '{gv.input}' and target column: '{gv.target}'")

# Keyword options for DataGen. The spellings `min_precentage`, `max_precentage`
# and `delete_cahce` are kept exactly as they were; presumably they match
# DataGen's parameter names — confirm against dataset.py before "fixing" them.
datagen_options = dict(
    input_col=gv.input,
    target_col=gv.target,
    batch_size=gv.batch_size,
    num_batches=32,
    patch_size=gv.patch_size,
    min_precentage=0.0,
    max_precentage=1.0,
    augment=False,
    norm_type="std",
    delete_cahce=True,
)

# DataGen object for the test dataset
test_dataset = DataGen(gv.test_ds_path, **datagen_options)

print("Dataset created successfully!")

In [None]:
# Load the pre-trained MaskInterpreter model
print(f"Loading model from: {model_path}")

try:
    model = keras.models.load_model(model_path)
    print("Model loaded successfully!")
    print("Model summary:")
    model.summary()
except Exception as e:
    # Best-effort load: report the failure and keep going with model = None.
    # The full traceback is printed, as the analyze_th and visualization cells
    # do, so the root cause is not reduced to a one-line message.
    print(f"Error loading model: {e}")
    import traceback
    traceback.print_exc()
    # NOTE(review): later cells pass this None on to analyze_th together with
    # model_path; whether analyze_th then loads the model itself is not visible here.
    model = None

In [None]:
# Run analyze_th to generate predictions and importance masks.
# According to the notes that accompany this cell, the function will:
#   1. Generate predictions on original images
#   2. Generate importance masks
#   3. Apply thresholding at different levels
#   4. Compute mask efficacy (PCC between predictions)
#   5. Save visualizations and results

print("Running analyze_th for predictions and importance mask generation...")
print("This may take a while depending on the dataset size...")

# Demonstration subset: only the first 2 images
# (use range(len(test_dataset.list_of_image_keys)) for all)
num_images_to_analyze = 2
images_to_analyze = range(num_images_to_analyze)

try:
    # Built inside the try so that a missing notebook variable is reported by
    # the handler below, exactly like a failure inside analyze_th itself.
    analysis_options = dict(
        dataset=test_dataset,
        mode="regular",            # Mode options: "agg", "loo", "mask", "regular"
        manual_th="full",          # Full importance mask without thresholding
        save_image=True,           # Save visualizations
        save_histo=False,          # Save importance mask histograms
        weighted_pcc=False,        # Use regular PCC (not weighted)
        model_path=model_path,
        model=model,
        compound=None,             # No compound/drug perturbation
        images=images_to_analyze,  # Which images to analyze
        noise_scale=1.0,           # Noise scale for perturbation
        save_results=True,
        results_save_path=None,    # Will save in model_path/predictions/
    )
    analyze_th(**analysis_options)
    print("analyze_th completed successfully!")
except Exception as exc:
    import traceback
    print(f"Error during analyze_th: {exc}")
    traceback.print_exc()

In [None]:
# Function to visualize results from saved predictions
def auto_balance(image, low_pct=0.1, high_pct=99.9):
    """Auto balance the image similar to ImageJ's Auto Contrast function.

    The intensities are linearly stretched so that the ``low_pct`` percentile maps
    to 0 and the ``high_pct`` percentile maps to 1; values outside that window are
    clipped.

    Args:
        image: Array-like of intensities (any numeric dtype). It is not modified.
        low_pct: Lower percentile used as the black point (default 0.1).
        high_pct: Upper percentile used as the white point (default 99.9).

    Returns:
        A new float32 array with the same shape as ``image`` and values in [0, 1].
        If both percentiles coincide (e.g. a constant image) there is no contrast
        to stretch and an all-zero array is returned instead of dividing by zero.
    """
    # Work in float32 to avoid overflow/underflow with integer inputs
    image = np.asarray(image, dtype=np.float32)

    # Plain Python floats keep the arithmetic below in float32 regardless of the
    # scalar type np.percentile returns.
    plow, phigh = (float(p) for p in np.percentile(image, (low_pct, high_pct)))

    span = phigh - plow
    if span == 0:
        # Flat image: the original expression would evaluate 0/0 and yield NaNs
        return np.zeros_like(image)

    # Stretch the values to the full range [0, 1]
    return np.clip((image - plow) / span, 0, 1)

def _read_tiff_if_present(reader, path):
    """Return ``reader(path)`` when ``path`` exists on disk, otherwise ``None``."""
    return reader(path) if os.path.exists(path) else None


def _ensure_channel_axis(volume):
    """Append a trailing channel axis to a 3-D array; leave ``None`` and other ranks untouched."""
    if volume is not None and len(volume.shape) == 3:
        return volume[..., np.newaxis]
    return volume


def _balanced_slice(volume, z_slice):
    """Return the auto-balanced 2-D plane ``volume[z_slice, :, :, 0]``."""
    return auto_balance(volume[z_slice, :, :, 0])


def _show_gray_panel(fig, position, title, volume, z_slice):
    """Add panel ``position`` of the 3x3 grid; draw the slice in grayscale if ``volume`` is given."""
    ax = fig.add_subplot(3, 3, position)
    if volume is not None:
        ax.imshow(_balanced_slice(volume, z_slice), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
    return ax


def _show_overlay_panel(fig, position, title, base_volume, mask_volume, z_slice):
    """Add panel ``position``; draw the mask in red on top of the grayscale base if the base is given."""
    ax = fig.add_subplot(3, 3, position)
    if base_volume is not None:
        ax.imshow(_balanced_slice(base_volume, z_slice), cmap='gray', alpha=0.7)
        ax.imshow(_balanced_slice(mask_volume, z_slice), cmap='Reds', alpha=0.4)
    ax.set_title(title)
    ax.axis('off')
    return ax


def visualize_results(model_path, image_index=0, z_slice=None):
    """
    Load and visualize saved prediction results from model/predictions directory.

    Reads the TIFF files under ``{model_path}/predictions/{image_index}`` and draws
    a 3x3 figure for one z-slice: input / target / prediction, the importance mask
    and its overlays, and the noisy input / noisy prediction / prediction
    difference. Only the prediction file is required; panels whose files are
    missing are left empty or omitted.

    Args:
        model_path: Path to the model directory
        image_index: Index of the image to visualize
        z_slice: Specific z-slice to display (None for middle slice). Indexes the
            first axis of the loaded arrays.

    Returns:
        ``(pred_image, mask_image)`` with a trailing channel axis added to 3-D
        arrays; ``mask_image`` is None when no mask file exists. Returns
        ``(None, None)`` when the prediction file is not found.

    Raises:
        IndexError: If ``z_slice`` is outside the prediction's first axis. Raised
            before any figure is created.
    """
    from cell_imaging_utils.image.image_utils import ImageUtils

    predictions_dir = f"{model_path}/predictions/{image_index}"

    # Expected files for the specified image
    input_file = f"{predictions_dir}/input_{image_index}.tiff"
    target_file = f"{predictions_dir}/target_{image_index}.tiff"
    pred_file = f"{predictions_dir}/unet_prediction_{image_index}.tiff"
    mask_file = f"{predictions_dir}/full/mask_{image_index}.tiff"
    noisy_input_file = f"{predictions_dir}/full/noisy_input_{image_index}.tiff"
    noisy_pred_file = f"{predictions_dir}/full/noisy_unet_prediction_{image_index}.tiff"

    # The prediction is mandatory; everything else is optional
    if not os.path.exists(pred_file):
        print(f"Prediction file not found: {pred_file}")
        return None, None

    # Load TIFF arrays using ImageUtils
    print("Loading images...")
    input_image = _read_tiff_if_present(ImageUtils.imread, input_file)
    target_image = _read_tiff_if_present(ImageUtils.imread, target_file)
    pred_image = ImageUtils.imread(pred_file)
    mask_image = _read_tiff_if_present(ImageUtils.imread, mask_file)
    noisy_input_image = _read_tiff_if_present(ImageUtils.imread, noisy_input_file)
    noisy_pred_image = _read_tiff_if_present(ImageUtils.imread, noisy_pred_file)

    print(f"Loaded prediction shape: {pred_image.shape}")
    if mask_image is not None:
        print(f"Loaded mask shape: {mask_image.shape}")

    # Ensure images have channel dimension so [z, :, :, 0] indexing works below
    pred_image = _ensure_channel_axis(pred_image)
    input_image = _ensure_channel_axis(input_image)
    target_image = _ensure_channel_axis(target_image)
    mask_image = _ensure_channel_axis(mask_image)
    noisy_input_image = _ensure_channel_axis(noisy_input_image)
    noisy_pred_image = _ensure_channel_axis(noisy_pred_image)

    # Determine z-slice to display; reject out-of-range values up front so a bad
    # index does not fail halfway through plotting and leave a partial figure.
    depth = pred_image.shape[0]
    if z_slice is None:
        z_slice = depth // 2
    elif not -depth <= z_slice < depth:
        raise IndexError(
            f"z_slice {z_slice} is out of range for a prediction with {depth} slices"
        )

    # Create visualization
    fig = plt.figure(figsize=(18, 12))

    # Row 1: Original images (grayscale with auto-balance)
    _show_gray_panel(fig, 1, 'Input Signal', input_image, z_slice)
    _show_gray_panel(fig, 2, 'Target', target_image, z_slice)
    _show_gray_panel(fig, 3, 'Prediction', pred_image, z_slice)

    # Row 2: Importance mask (heatmap) and overlays — only when a mask was saved
    if mask_image is not None:
        ax_mask = fig.add_subplot(3, 3, 4)
        mask_plot = ax_mask.imshow(_balanced_slice(mask_image, z_slice), cmap='hot')
        ax_mask.set_title('Importance Mask')
        ax_mask.axis('off')
        fig.colorbar(mask_plot, ax=ax_mask, fraction=0.046)

        _show_overlay_panel(fig, 5, 'Input + Mask Overlay', input_image, mask_image, z_slice)
        _show_overlay_panel(fig, 6, 'Target + Mask Overlay', target_image, mask_image, z_slice)

    # Row 3: Noisy versions (grayscale with auto-balance) and difference map (heatmap)
    if noisy_input_image is not None:
        _show_gray_panel(fig, 7, 'Noisy Input', noisy_input_image, z_slice)

    if noisy_pred_image is not None:
        _show_gray_panel(fig, 8, 'Noisy Prediction', noisy_pred_image, z_slice)

        # Difference map (keep as heatmap), symmetric color scale around zero.
        # Both images are auto-balanced independently before subtracting.
        ax_diff = fig.add_subplot(3, 3, 9)
        diff = _balanced_slice(pred_image, z_slice) - _balanced_slice(noisy_pred_image, z_slice)
        diff_limit = np.abs(diff).max()
        diff_plot = ax_diff.imshow(diff, cmap='RdBu_r', vmin=-diff_limit, vmax=diff_limit)
        ax_diff.set_title('Prediction Difference')
        ax_diff.axis('off')
        fig.colorbar(diff_plot, ax=ax_diff, fraction=0.046)

    fig.suptitle(f'Image {image_index} - Z-slice {z_slice}/{pred_image.shape[0]}', fontsize=16, fontweight='bold')
    fig.tight_layout()
    plt.show()

    return pred_image, mask_image

# Render the summary figure for the first image.
# z_slice=None selects the middle slice; any explicit index (e.g. 15) may be given instead.
try:
    pred_image, mask_image = visualize_results(model_path=model_path, image_index=0, z_slice=19)
    # visualize_results returns (None, None) when the prediction file is missing
    if pred_image is not None:
        print("Visualization completed!")
except Exception as exc:
    import traceback
    print(f"Error during visualization: {exc}")
    traceback.print_exc()

In [None]:
# Load and display results from analyze_th
results_dir = f"{model_path}/predictions"

# isdir rather than exists: os.listdir below requires a directory, and a plain
# file at this path should be reported as "not found" instead of crashing.
if os.path.isdir(results_dir):
    print(f"Results directory: {results_dir}")

    # Check for results CSV files
    # NOTE(review): the "resuls" spelling is kept as-is; presumably it matches the
    # file names analyze_th writes — confirm before renaming.
    pcc_results_path = f"{results_dir}/pcc_resuls.csv"
    mask_size_results_path = f"{results_dir}/mask_size_resuls.csv"

    if os.path.exists(pcc_results_path):
        pcc_results = pd.read_csv(pcc_results_path)
        print("\nPCC Results (Mask Efficacy):")
        print(pcc_results)

    if os.path.exists(mask_size_results_path):
        mask_size_results = pd.read_csv(mask_size_results_path)
        print("\nMask Size Results:")
        print(mask_size_results)

    # List generated images
    print(f"\nGenerated files in {results_dir}:")
    # os.listdir returns entries in arbitrary order; sort for a stable listing
    for item in sorted(os.listdir(results_dir)):
        item_path = os.path.join(results_dir, item)
        if os.path.isdir(item_path):
            print(f"  [DIR] {item}")
        else:
            print(f"  [FILE] {item}")
else:
    print(f"Results directory not found: {results_dir}")