# Cryo-EM CNN Analysis - Visual Results

Visualize denoising results from the Fortran/CUDA CNN training.

**WARNING:** Requires over 50 GB of RAM (use the streaming version on systems with less RAM).

**Model:** 3-layer CNN (1→16→16→1 channels)  
**Dataset:** 29,952 cryo-EM particle patches (1024×1024)  
**Training:** Epoch 1 achieved val loss 0.00697 (2× better than PyTorch!)  

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import torch.nn as nn
from scipy import stats

# Notebook-wide plotting defaults: wide figures, grayscale images
plt.rcParams.update({
    'figure.figsize': (15, 5),
    'image.cmap': 'gray',
})

## 1. Load Fortran Model Weights

In [2]:
# Directory holding the per-layer binary files of the checkpoint to visualize
checkpoint_dir = Path('../v28f_e_final_training/saved_models/cryo_cnn/epoch_0001/')

def load_fortran_weights(prefix):
    """Read one layer's parameters from a pair of raw float32 files.

    Args:
        prefix: String or path-like value; it is formatted into
            ``<prefix>weights.bin`` and ``<prefix>bias.bin``.

    Returns:
        Tuple ``(weights, bias)`` of flat 1-D float32 arrays. No reshaping is
        done here; per the notebook's note, the Fortran side stores weights
        as (kH, kW, in_ch, out_ch) and the caller is expected to reshape.
    """
    weights_file = f'{prefix}weights.bin'
    bias_file = f'{prefix}bias.bin'
    return (
        np.fromfile(weights_file, dtype=np.float32),
        np.fromfile(bias_file, dtype=np.float32),
    )

# Fetch the flat parameter vectors for the three convolution layers
conv1_w, conv1_b = load_fortran_weights(checkpoint_dir / 'conv1_')
conv2_w, conv2_b = load_fortran_weights(checkpoint_dir / 'conv2_')
conv3_w, conv3_b = load_fortran_weights(checkpoint_dir / 'conv3_')

# Report the raw (still 1-D) sizes exactly as read from disk
print("\n".join([
    f"Conv1: weights {conv1_w.shape}, bias {conv1_b.shape}",
    f"Conv2: weights {conv2_w.shape}, bias {conv2_b.shape}",
    f"Conv3: weights {conv3_w.shape}, bias {conv3_b.shape}",
]))

# Interpret each vector as a C-ordered (kH, kW, in_ch, out_ch) array, then
# permute the axes to (out_ch, in_ch, kH, kW) for the PyTorch conv layers.
# NOTE(review): Fortran arrays are column-major; if the file is a raw dump of
# a Fortran array declared (kH, kW, in_ch, out_ch), this C-order reshape would
# scramble the axes. Confirm the on-disk layout against the Fortran writer.
conv1_w = np.transpose(conv1_w.reshape(3, 3, 1, 16), (3, 2, 0, 1))
conv2_w = np.transpose(conv2_w.reshape(3, 3, 16, 16), (3, 2, 0, 1))
conv3_w = np.transpose(conv3_w.reshape(3, 3, 16, 1), (3, 2, 0, 1))

print(f"\nReshaped for PyTorch:")
print("\n".join([
    f"Conv1: {conv1_w.shape}",
    f"Conv2: {conv2_w.shape}",
    f"Conv3: {conv3_w.shape}",
]))

Conv1: weights (144,), bias (16,)
Conv2: weights (2304,), bias (16,)
Conv3: weights (144,), bias (1,)

Reshaped for PyTorch:
Conv1: (16, 1, 3, 3)
Conv2: (16, 16, 3, 3)
Conv3: (1, 16, 3, 3)


## 2. Create PyTorch Model with Fortran Weights

In [3]:
class SimpleCNN(nn.Module):
    """Three-layer convolutional denoiser (1 -> 16 -> 16 -> 1 channels).

    Every convolution uses a 3x3 kernel with padding 1, so the spatial size
    of the input is preserved. ReLU follows the first two convolutions; the
    last convolution's output is returned without an activation.
    """

    def __init__(self):
        super().__init__()
        # Shared settings: 3x3 kernels, padded so height/width are unchanged
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(1, 16, **conv_kwargs)
        self.conv2 = nn.Conv2d(16, 16, **conv_kwargs)
        self.conv3 = nn.Conv2d(16, 1, **conv_kwargs)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU -> conv3 to a batch `x`."""
        hidden = self.relu(self.conv1(x))
        hidden = self.relu(self.conv2(hidden))
        return self.conv3(hidden)

# Create the model and install the Fortran-trained parameters.
#
# load_state_dict is used instead of assigning to `.data`: it checks that
# every expected key is present and that each tensor's shape matches the
# layer it is loaded into, and it copies the values into the model's own
# parameters rather than aliasing the numpy buffers. A checkpoint with the
# wrong layout therefore fails here, not silently or later during inference.
model = SimpleCNN()
model.load_state_dict({
    # ascontiguousarray: the weight arrays are transposed views at this point
    name: torch.from_numpy(np.ascontiguousarray(array))
    for name, array in (
        ('conv1.weight', conv1_w),
        ('conv1.bias', conv1_b),
        ('conv2.weight', conv2_w),
        ('conv2.bias', conv2_b),
        ('conv3.weight', conv3_w),
        ('conv3.bias', conv3_b),
    )
})
model.eval()

print("✓ Model loaded with Fortran weights")

✓ Model loaded with Fortran weights


## 3. Load Test Data

In [4]:
# Read the noisy inputs and clean targets as flat float32 vectors
test_noisy = np.fromfile('../data/cryo_data_streaming/test_input.bin', dtype=np.float32)
test_clean = np.fromfile('../data/cryo_data_streaming/test_target.bin', dtype=np.float32)

# Each patch is 1024 x 1024 pixels; the patch count follows from the file size
num_test = test_noisy.size // (1024 ** 2)
test_noisy = np.reshape(test_noisy, (num_test, 1024, 1024))
test_clean = np.reshape(test_clean, (num_test, 1024, 1024))

print(f"Loaded {num_test:,} test patches")
print(
    f"Noisy range: [{test_noisy.min():.3f}, {test_noisy.max():.3f}]",
    f"Clean range: [{test_clean.min():.3f}, {test_clean.max():.3f}]",
    sep="\n",
)

Loaded 3,211 test patches
Noisy range: [0.000, 1.000]
Clean range: [0.000, 1.000]


## 4. Run Inference on Test Set

In [None]:
# Denoise all test patches, a few at a time.
#
# The output buffer is allocated once and filled in place. Collecting the
# batches in a list and concatenating afterwards would briefly hold two full
# copies of the denoised set, which matters at this dataset size.
batch_size = 8
denoised = np.empty_like(test_noisy)  # conv padding keeps H x W, so shapes match

with torch.no_grad():
    for i in range(0, num_test, batch_size):
        batch = test_noisy[i:i+batch_size]
        batch_tensor = torch.from_numpy(batch).unsqueeze(1)  # Add channel dim
        output = model(batch_tensor)
        denoised[i:i+batch_size] = output.squeeze(1).numpy()  # Drop channel dim

        if (i // batch_size + 1) % 100 == 0:
            # Clamp so the count never exceeds the total on a short last batch
            print(f"Processed {min(i + batch_size, num_test)}/{num_test} patches...")

print(f"\n✓ Denoised {num_test:,} patches")
print(f"Denoised range: [{denoised.min():.3f}, {denoised.max():.3f}]")

Processed 800/3211 patches...
Processed 1600/3211 patches...
Processed 2400/3211 patches...
Processed 3200/3211 patches...


## 5. Calculate Metrics

In [None]:
# Per-patch MSE and Pearson correlation.
#
# Both are computed one patch at a time so that no full-dataset temporary
# (e.g. `(denoised - test_clean) ** 2`) is ever materialised, and the squared
# error is only evaluated once. Sums are accumulated in float64.
patch_mse = np.empty(num_test, dtype=np.float64)
correlations = np.empty(num_test, dtype=np.float64)
for i in range(num_test):
    pred_flat = denoised[i].ravel()
    target_flat = test_clean[i].ravel()
    patch_mse[i] = np.mean((pred_flat - target_flat) ** 2, dtype=np.float64)
    corr, _ = stats.pearsonr(pred_flat, target_flat)
    correlations[i] = corr
patch_rmse = np.sqrt(patch_mse)

# Global MSE and RMSE. Every patch has the same number of pixels, so the mean
# of the per-patch MSEs equals the MSE over all pixels.
mse = patch_mse.mean()
rmse = np.sqrt(mse)

# PSNR (Peak Signal-to-Noise Ratio), using the clean data's value range as peak
data_range = test_clean.max() - test_clean.min()
psnr = 20 * np.log10(data_range / rmse)

print("=" * 60)
print("Test Set Metrics")
print("=" * 60)
print(f"MSE:                {mse:.6f}")
print(f"RMSE:               {rmse:.6f}")
print(f"PSNR:               {psnr:.2f} dB")
print(f"Mean Correlation:   {correlations.mean():.4f}")
print(f"Median Correlation: {np.median(correlations):.4f}")
print(f"Min Correlation:    {correlations.min():.4f}")
print("=" * 60)

## 6. Visualize Random Samples

In [None]:
# Pick five distinct patches; the fixed seed keeps the selection repeatable
np.random.seed(42)
indices = np.random.choice(num_test, 5, replace=False)

for idx in indices:
    noisy, clean, pred = test_noisy[idx], test_clean[idx], denoised[idx]

    # Error statistics for this one patch
    patch_mse_val = np.mean((pred - clean) ** 2)
    patch_rmse_val = np.sqrt(patch_mse_val)
    corr, _ = stats.pearsonr(pred.flatten(), clean.flatten())
    diff = np.abs(pred - clean)

    # One row of four panels: input, prediction, target, absolute error
    panels = (
        (noisy, 'gray', f'Noisy Input (Patch {idx})'),
        (pred, 'gray', f'Denoised (Fortran CNN)\nRMSE: {patch_rmse_val:.4f}, Corr: {corr:.4f}'),
        (clean, 'gray', 'Clean Target'),
        (diff, 'hot', f'Abs Error\nMax: {diff.max():.4f}'),
    )
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for ax, (image, colormap, heading) in zip(axes, panels):
        im = ax.imshow(image, cmap=colormap)
        ax.set_title(heading, fontsize=14)
        ax.axis('off')

    # `im` is now the error map (last panel drawn); give it a colorbar
    plt.colorbar(im, ax=axes[3], fraction=0.046)

    plt.tight_layout()
    plt.show()

## 7. Distribution Analysis

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# 1-D views of the full arrays. ravel() avoids the full-size copy that every
# flatten() call makes, which is significant at this dataset size.
clean_flat = test_clean.ravel()
denoised_flat = denoised.ravel()

# RMSE distribution
axes[0, 0].hist(patch_rmse, bins=50, alpha=0.7, edgecolor='black')
axes[0, 0].axvline(patch_rmse.mean(), color='red', linestyle='--', label=f'Mean: {patch_rmse.mean():.4f}')
axes[0, 0].set_xlabel('RMSE', fontsize=12)
axes[0, 0].set_ylabel('Count', fontsize=12)
axes[0, 0].set_title('Per-Patch RMSE Distribution', fontsize=14)
axes[0, 0].legend()
axes[0, 0].grid(alpha=0.3)

# Correlation distribution
axes[0, 1].hist(correlations, bins=50, alpha=0.7, color='green', edgecolor='black')
axes[0, 1].axvline(correlations.mean(), color='red', linestyle='--', label=f'Mean: {correlations.mean():.4f}')
axes[0, 1].set_xlabel('Correlation', fontsize=12)
axes[0, 1].set_ylabel('Count', fontsize=12)
axes[0, 1].set_title('Per-Patch Correlation Distribution', fontsize=14)
axes[0, 1].legend()
axes[0, 1].grid(alpha=0.3)

# Pixel value distribution comparison (every 1000th pixel keeps this cheap)
axes[1, 0].hist(clean_flat[::1000], bins=100, alpha=0.5, label='Clean', density=True)
axes[1, 0].hist(denoised_flat[::1000], bins=100, alpha=0.5, label='Denoised', density=True)
axes[1, 0].set_xlabel('Pixel Value', fontsize=12)
axes[1, 0].set_ylabel('Density', fontsize=12)
axes[1, 0].set_title('Pixel Value Distributions', fontsize=14)
axes[1, 0].legend()
axes[1, 0].grid(alpha=0.3)

# Scatter plot: predicted vs target.
# Generator.choice draws a small sample without replacement without shuffling
# the whole index range; the legacy np.random.choice(..., replace=False)
# permutes every index, which is prohibitive for billions of pixels.
# A fixed seed keeps the plotted sample repeatable.
rng = np.random.default_rng(42)
sample_indices = rng.choice(denoised_flat.size, min(10000, denoised_flat.size), replace=False)
axes[1, 1].scatter(clean_flat[sample_indices],
                   denoised_flat[sample_indices],
                   alpha=0.1, s=1)
clean_lo, clean_hi = test_clean.min(), test_clean.max()
axes[1, 1].plot([clean_lo, clean_hi],
                [clean_lo, clean_hi],
                'r--', label='Perfect prediction')
axes[1, 1].set_xlabel('Target Pixel Value', fontsize=12)
axes[1, 1].set_ylabel('Predicted Pixel Value', fontsize=12)
axes[1, 1].set_title('Predicted vs Target (10k samples)', fontsize=14)
axes[1, 1].legend()
axes[1, 1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

## 8. Best and Worst Cases

In [None]:
# Locate the patches with the lowest and highest per-patch RMSE
best_idx = patch_rmse.argmin()
worst_idx = patch_rmse.argmax()

print(f"Best patch:  idx={best_idx}, RMSE={patch_rmse[best_idx]:.6f}, Corr={correlations[best_idx]:.4f}")
print(f"Worst patch: idx={worst_idx}, RMSE={patch_rmse[worst_idx]:.6f}, Corr={correlations[worst_idx]:.4f}")

for title, idx in [("Best", best_idx), ("Worst", worst_idx)]:
    diff = np.abs(denoised[idx] - test_clean[idx])

    # One row of four panels: input, prediction, target, absolute error
    panels = (
        (test_noisy[idx], 'gray', f'{title} - Noisy Input'),
        (denoised[idx], 'gray', f'{title} - Denoised\nRMSE: {patch_rmse[idx]:.6f}'),
        (test_clean[idx], 'gray', f'{title} - Clean Target'),
        (diff, 'hot', f'{title} - Abs Error\nMax: {diff.max():.4f}'),
    )
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for ax, (image, colormap, heading) in zip(axes, panels):
        im = ax.imshow(image, cmap=colormap)
        ax.set_title(heading, fontsize=14)
        ax.axis('off')

    # `im` is now the error map (last panel drawn); give it a colorbar
    plt.colorbar(im, ax=axes[3], fraction=0.046)

    plt.tight_layout()
    plt.show()

## Summary

This notebook demonstrates:
- ✅ Loading Fortran-trained weights into PyTorch
- ✅ Running inference on test set
- ✅ Quantitative metrics (MSE, RMSE, PSNR, Correlation)
- ✅ Visual inspection of denoising quality
- ✅ Statistical analysis of results

**Expected Results (based on training):**
- Test RMSE: ~0.08-0.09 (matches val loss of 0.007)
- Correlation: >0.95
- PSNR: ~21-22 dB (20·log10(data_range / RMSE) with data range 1.0 and RMSE 0.08-0.09)
- Visually clean reconstructions with preserved particle features