In [14]:
import numpy as np
import pickle as pkl
import torch


In [15]:
s1 = np.load('/orcd/data/omarabu/001/tanush20/data/MERFISH/S1R1.npz', allow_pickle=True)
with open('/orcd/data/omarabu/001/njwfish/counting_flows/results/merfish/x_0_generated.pkl', 'rb') as f:
    predicted_counts = pkl.load(f)
combined_predicted_counts = torch.vstack(predicted_counts)

In [16]:
# Convert predicted counts to s1 format
print("Converting predicted counts to s1 format...")

# Get the original spots structure from s1
imgs = s1['imgs'][:]
original_spots = s1['spots']
n_spots = len(original_spots)

# Create new spots array with consecutive indices
new_spots = []
current_index = 0
cells_to_remove = []
total_cells = 0
for i in range(n_spots):
    # Get number of cells in this spot from original data
    num_cells_in_spot = 0
    for cell in original_spots[i]:
        if max(imgs[cell].shape)>256:
            cells_to_remove.append(cell)
        else:
            num_cells_in_spot += 1
    total_cells += num_cells_in_spot
            
    # Create array of consecutive indices for this spot
    spot_indices = np.arange(current_index, current_index + num_cells_in_spot)
    new_spots.append(spot_indices)
    
    # Update current index for next spot
    current_index += num_cells_in_spot

print(f"Created {len(new_spots)} spots")
print(f"Total cells: {current_index}")
print(f"Combined predicted counts shape: {combined_predicted_counts.shape}")

# Verify the conversion
print("\nVerification:")
print(f"Original spots[0]: {original_spots[0][:5]}... (length: {len(original_spots[0])})")
print(f"New spots[0]: {new_spots[0][:5]}... (length: {len(new_spots[0])})")
print(f"Original spots[1]: {original_spots[1][:5]}... (length: {len(original_spots[1])})")
print(f"New spots[1]: {new_spots[1][:5]}... (length: {len(new_spots[1])})")


Converting predicted counts to s1 format...
Created 4358 spots
Total cells: 73613
Combined predicted counts shape: torch.Size([157669, 649])

Verification:
Original spots[0]: [30684 30690 30692 30693 30698]... (length: 5)
New spots[0]: [0 1 2 3 4]... (length: 5)
Original spots[1]: [30683 30688 30694 30697 30699]... (length: 11)
New spots[1]: [5 6 7 8 9]... (length: 11)


In [17]:
# Convert torch tensor to numpy if needed
if isinstance(combined_predicted_counts, torch.Tensor):
    combined_predicted_counts_np = combined_predicted_counts.cpu().numpy()
else:
    combined_predicted_counts_np = combined_predicted_counts
combined_predicted_counts_np = combined_predicted_counts_np[:total_cells]
# Create the new data structure in s1 format
predicted_s1_format = {
    'predicted_counts': combined_predicted_counts_np,  # Combined counts matrix
    'spots': new_spots,  # List of arrays with cell indices for each spot
    'imgs': s1['imgs'],  # Copy other fields from original s1
    'x_um': s1['x_um'],
    'y_um': s1['y_um']
}

print("Created predicted_s1_format with keys:", list(predicted_s1_format.keys()))
print(f"Counts shape: {predicted_s1_format['predicted_counts'].shape}")
print(f"Number of spots: {len(predicted_s1_format['spots'])}")


Created predicted_s1_format with keys: ['predicted_counts', 'spots', 'imgs', 'x_um', 'y_um']
Counts shape: (73613, 649)
Number of spots: 4358


In [20]:
new_spots = np.array(new_spots, dtype=object)

np.savez("/orcd/data/omarabu/001/tanush20/data/MERFISH/S1R1_predicted_counts_2.npz",
        spots = new_spots,
        counts = predicted_s1_format['predicted_counts'])

In [22]:
# Load both datasets for evaluation
print("Loading datasets for evaluation...")
true_data = np.load('/orcd/data/omarabu/001/tanush20/data/MERFISH/S1R1.npz', allow_pickle=True)
predicted_data = np.load('/orcd/data/omarabu/001/tanush20/data/MERFISH/S1R1_predicted_counts_2.npz', allow_pickle=True)

print(f"True data keys: {list(true_data.keys())}")
print(f"Predicted data keys: {list(predicted_data.keys())}")
print(f"True counts shape: {true_data['counts'].shape}")
print(f"Predicted counts shape: {predicted_data['counts'].shape}")
print(f"Number of spots: {len(true_data['spots'])}")


Loading datasets for evaluation...
True data keys: ['imgs', 'counts', 'spots', 'x_um', 'y_um', 'annotations', 'visium_simulated', 'cell_type_counts', 'n_cells']
Predicted data keys: ['spots', 'counts']
True counts shape: (78329, 649)
Predicted counts shape: (73613, 649)
Number of spots: 4358


In [25]:
n_spots

4358

In [28]:
# Evaluation function
def evaluate_predictions(true_data, predicted_data, imgs, max_img_size=256):
    """
    Evaluate predicted counts against true counts and spot means
    
    Args:
        true_data: npz file with true data
        predicted_data: npz file with predicted data
        max_img_size: maximum image size to consider (default 256)
    
    Returns:
        dict with MSE results
    """
    true_counts = true_data['counts']
    predicted_counts = predicted_data['counts']
    true_spots = true_data['spots']
    predicted_spots = predicted_data['spots']
    
    n_spots = len(true_spots)
    
    # Store MSE results for each spot
    mse_predicted_vs_true = []
    mse_predicted_vs_spot_mean = []
    mse_true_vs_spot_mean = []
    valid_spots = 0
    
    print(f"Evaluating {n_spots} spots...")
    
    for spot_idx in range(n_spots):
        if spot_idx % 500 == 0:
            print(f"Processing spot {spot_idx}/{n_spots}")
        
        # Get cell indices for this spot
        true_cell_indices = np.array(true_spots[spot_idx]).astype(int)
        predicted_cell_indices = np.array(predicted_spots[spot_idx]).astype(int)
        
        # Get counts for this spot
        true_spot_counts = true_counts[true_cell_indices]
        predicted_spot_counts = predicted_counts[predicted_cell_indices]
        
        # Filter cells based on image size (only for true data since predicted is already filtered)
        valid_cells = []
        for i,cell_idx in enumerate(true_cell_indices):
            cell_img = imgs[cell_idx]
            if max(cell_img.shape) <= 256:  # Fixed: should be <= not >
                valid_cells.append(i)
        if len(valid_cells) == 0:
            continue  # Skip spots with no valid cells
        
        # Filter true counts to only valid cells (predicted is already filtered)
        true_spot_counts_valid = true_spot_counts[valid_cells]
        predicted_spot_counts_valid = predicted_spot_counts  # Already filtered
        
        # Calculate spot mean from true counts (filtered)
        spot_mean = np.mean(true_spot_counts_valid, axis=0)
        
        # Calculate MSE: predicted vs true
        mse_pred_true = np.mean((predicted_spot_counts_valid - true_spot_counts_valid) ** 2)
        mse_predicted_vs_true.append(mse_pred_true)
        
        # Calculate MSE: predicted vs spot mean
        mse_pred_mean = np.mean((predicted_spot_counts_valid - spot_mean) ** 2)
        mse_predicted_vs_spot_mean.append(mse_pred_mean)
        
        # Calculate MSE: true vs spot mean
        mse_true_mean = np.mean((true_spot_counts_valid - spot_mean) ** 2)
        mse_true_vs_spot_mean.append(mse_true_mean)
        
        valid_spots += 1
    
    # Calculate overall MSEs
    overall_mse_predicted_vs_true = np.mean(mse_predicted_vs_true)
    overall_mse_predicted_vs_spot_mean = np.mean(mse_predicted_vs_spot_mean)
    overall_mse_true_vs_spot_mean = np.mean(mse_true_vs_spot_mean)
    
    results = {
        'overall_mse_predicted_vs_true': overall_mse_predicted_vs_true,
        'overall_mse_predicted_vs_spot_mean': overall_mse_predicted_vs_spot_mean,
        'overall_mse_true_vs_spot_mean': overall_mse_true_vs_spot_mean,
        'spot_mse_predicted_vs_true': mse_predicted_vs_true,
        'spot_mse_predicted_vs_spot_mean': mse_predicted_vs_spot_mean,
        'spot_mse_true_vs_spot_mean': mse_true_vs_spot_mean,
        'valid_spots': valid_spots,
        'total_spots': n_spots
    }
    
    return results

# Run evaluation
print("Starting evaluation...")
results = evaluate_predictions(true_data, predicted_data, imgs, max_img_size=256)


Starting evaluation...
Evaluating 4358 spots...
Processing spot 0/4358
Processing spot 500/4358
Processing spot 1000/4358
Processing spot 1500/4358
Processing spot 2000/4358
Processing spot 2500/4358
Processing spot 3000/4358
Processing spot 3500/4358
Processing spot 4000/4358


In [29]:
# Display results
print("="*70)
print("EVALUATION RESULTS")
print("="*70)
print(f"Total spots processed: {results['total_spots']}")
print(f"Valid spots (with cells ≤256px): {results['valid_spots']}")
print(f"Valid spots percentage: {results['valid_spots']/results['total_spots']*100:.1f}%")
print()
print("OVERALL MSE COMPARISON:")
print(f"MSE (Predicted vs True):        {results['overall_mse_predicted_vs_true']:.6f}")
print(f"MSE (Predicted vs Spot Mean):   {results['overall_mse_predicted_vs_spot_mean']:.6f}")
print(f"MSE (True vs Spot Mean):        {results['overall_mse_true_vs_spot_mean']:.6f}")
print()
print("PERFORMANCE COMPARISON:")
print("1. Predicted vs Spot Mean:")
if results['overall_mse_predicted_vs_spot_mean'] < results['overall_mse_true_vs_spot_mean']:
    improvement = (results['overall_mse_true_vs_spot_mean'] - results['overall_mse_predicted_vs_spot_mean']) / results['overall_mse_true_vs_spot_mean'] * 100
    print(f"   ✅ Predictions are BETTER than true counts by {improvement:.1f}%")
else:
    degradation = (results['overall_mse_predicted_vs_spot_mean'] - results['overall_mse_true_vs_spot_mean']) / results['overall_mse_true_vs_spot_mean'] * 100
    print(f"   ❌ Predictions are WORSE than true counts by {degradation:.1f}%")

print("2. Predicted vs True:")
if results['overall_mse_predicted_vs_true'] < results['overall_mse_true_vs_spot_mean']:
    improvement = (results['overall_mse_true_vs_spot_mean'] - results['overall_mse_predicted_vs_true']) / results['overall_mse_true_vs_spot_mean'] * 100
    print(f"   ✅ Predictions are BETTER than spot mean by {improvement:.1f}%")
else:
    degradation = (results['overall_mse_predicted_vs_true'] - results['overall_mse_true_vs_spot_mean']) / results['overall_mse_true_vs_spot_mean'] * 100
    print(f"   ❌ Predictions are WORSE than spot mean by {degradation:.1f}%")

print()
print("DETAILED STATISTICS:")
print(f"Predicted vs True MSE - Mean:     {np.mean(results['spot_mse_predicted_vs_true']):.6f}")
print(f"Predicted vs True MSE - Std:      {np.std(results['spot_mse_predicted_vs_true']):.6f}")
print(f"Predicted vs Spot Mean MSE - Mean: {np.mean(results['spot_mse_predicted_vs_spot_mean']):.6f}")
print(f"Predicted vs Spot Mean MSE - Std:  {np.std(results['spot_mse_predicted_vs_spot_mean']):.6f}")
print(f"True vs Spot Mean MSE - Mean:     {np.mean(results['spot_mse_true_vs_spot_mean']):.6f}")
print(f"True vs Spot Mean MSE - Std:      {np.std(results['spot_mse_true_vs_spot_mean']):.6f}")


EVALUATION RESULTS
Total spots processed: 4358
Valid spots (with cells ≤256px): 4358
Valid spots percentage: 100.0%

OVERALL MSE COMPARISON:
MSE (Predicted vs True):        12.068824
MSE (Predicted vs Spot Mean):   1.995875
MSE (True vs Spot Mean):        9.996818

PERFORMANCE COMPARISON:
1. Predicted vs Spot Mean:
   ✅ Predictions are BETTER than true counts by 80.0%
2. Predicted vs True:
   ❌ Predictions are WORSE than spot mean by 20.7%

DETAILED STATISTICS:
Predicted vs True MSE - Mean:     12.068824
Predicted vs True MSE - Std:      8.758031
Predicted vs Spot Mean MSE - Mean: 1.995875
Predicted vs Spot Mean MSE - Std:  3.964895
True vs Spot Mean MSE - Mean:     9.996818
True vs Spot Mean MSE - Std:      6.221804
