# Synful Training Pipeline Visualization

This notebook visualizes the Synful training data pipeline stage by stage, from raw synapse annotations to training-ready batches. All inputs are synthetic stand-ins, so the cells run without access to real TSV files, MongoDB, or Zarr volumes.

## Pipeline Overview

1. **Data Loading**: Load synapse locations from TSV files or MongoDB
2. **Volume Access**: Load electron microscopy volumes from Zarr files  
3. **Spatial Extraction**: Extract training cubes around synapse locations
4. **Mask Generation**: Create binary masks at synapse locations
5. **Direction Vectors**: Generate pre→post direction vectors for multitask learning
6. **Data Augmentation**: Apply 3D geometric and intensity augmentations
7. **Final Training Data**: Visualize the complete training batch

Let's explore each step in detail!

In [None]:
# Import Required Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# 3D visualization
from mpl_toolkits.mplot3d import Axes3D

# Set plotting style
plt.style.use('default')
sns.set_palette("husl")

# Configure matplotlib for better output
%matplotlib inline
plt.rcParams['figure.figsize'] = (10, 8)
plt.rcParams['font.size'] = 12

print("✅ Successfully imported all required libraries")
print("📊 Ready for Synful pipeline visualization")

## Step 1: Load and Parse TSV Synapse Data

The first step is loading synapse location data. Synful supports multiple formats:
- **TSV files**: Simple tab-separated files with pre/post coordinates
- **MongoDB**: Production database with spatial indexing
- **Synthetic**: Generated data for testing

TSV format expected:
```
pre_x    pre_y    pre_z    post_x    post_y    post_z
1000.0   2000.0   500.0    1100.0    2100.0   510.0
1500.0   2500.0   600.0    1600.0    2600.0   610.0
...
```
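
As a rough sketch of how a file in this format could be read, the snippet below parses a tab-separated table with pandas and checks that the six expected columns are present. The helper `load_synapse_tsv`, the `EXPECTED_COLUMNS` list, and the in-memory sample text are illustrative assumptions, not part of Synful's API.

```python
import io

import pandas as pd

# Hypothetical in-memory TSV standing in for a real file on disk
tsv_text = (
    "pre_x\tpre_y\tpre_z\tpost_x\tpost_y\tpost_z\n"
    "1000.0\t2000.0\t500.0\t1100.0\t2100.0\t510.0\n"
    "1500.0\t2500.0\t600.0\t1600.0\t2600.0\t610.0\n"
)

EXPECTED_COLUMNS = ["pre_x", "pre_y", "pre_z", "post_x", "post_y", "post_z"]

def load_synapse_tsv(source):
    """Read a tab-separated synapse table and check its columns (illustrative helper)."""
    df = pd.read_csv(source, sep="\t")
    missing = [c for c in EXPECTED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"TSV is missing columns: {missing}")
    return df[EXPECTED_COLUMNS].astype(float)

tsv_df = load_synapse_tsv(io.StringIO(tsv_text))
print(tsv_df)
```

Passing a file path instead of the `io.StringIO` object should work the same way, since `pd.read_csv` accepts both.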

In [None]:
# Create sample TSV data for demonstration
def create_sample_tsv_data(n_synapses=100, volume_size=(10000, 10000, 1000)):
    """Create sample synapse data for visualization"""
    np.random.seed(42)  # For reproducible results
    
    # Generate random pre-synapse locations
    pre_x = np.random.uniform(1000, volume_size[0]-1000, n_synapses)
    pre_y = np.random.uniform(1000, volume_size[1]-1000, n_synapses)
    pre_z = np.random.uniform(100, volume_size[2]-100, n_synapses)
    
    # Generate post-synapse locations nearby (synapses are typically close)
    post_x = pre_x + np.random.normal(0, 50, n_synapses)  # 50nm std
    post_y = pre_y + np.random.normal(0, 50, n_synapses)
    post_z = pre_z + np.random.normal(0, 20, n_synapses)  # smaller in z
    
    # Create DataFrame
    synapse_data = pd.DataFrame({
        'pre_x': pre_x,
        'pre_y': pre_y, 
        'pre_z': pre_z,
        'post_x': post_x,
        'post_y': post_y,
        'post_z': post_z
    })
    
    return synapse_data

# Load or create synapse data
synapse_df = create_sample_tsv_data(n_synapses=200)

print(f"📊 Loaded synapse data: {len(synapse_df)} synapses")
print(f"📐 Data shape: {synapse_df.shape}")
print("\n🔍 First 5 synapses:")
print(synapse_df.head())

print("\n📈 Data statistics:")
print(synapse_df.describe())

In [None]:
# Visualize raw synapse data distribution
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# Pre-synapse locations
axes[0,0].scatter(synapse_df['pre_x'], synapse_df['pre_y'], 
                  c=synapse_df['pre_z'], cmap='viridis', alpha=0.7)
axes[0,0].set_xlabel('X (nm)')
axes[0,0].set_ylabel('Y (nm)')
axes[0,0].set_title('Pre-synapse Locations (XY, colored by Z)')
axes[0,0].grid(True, alpha=0.3)

# Post-synapse locations  
axes[0,1].scatter(synapse_df['post_x'], synapse_df['post_y'],
                  c=synapse_df['post_z'], cmap='viridis', alpha=0.7)
axes[0,1].set_xlabel('X (nm)')
axes[0,1].set_ylabel('Y (nm)')
axes[0,1].set_title('Post-synapse Locations (XY, colored by Z)')
axes[0,1].grid(True, alpha=0.3)

# Z-distribution
axes[1,0].hist([synapse_df['pre_z'], synapse_df['post_z']], 
               bins=20, alpha=0.7, label=['Pre', 'Post'])
axes[1,0].set_xlabel('Z coordinate (nm)')
axes[1,0].set_ylabel('Count')
axes[1,0].set_title('Z-coordinate Distribution')
axes[1,0].legend()
axes[1,0].grid(True, alpha=0.3)

# Coordinate ranges
coord_ranges = pd.DataFrame({
    'Min': [synapse_df[col].min() for col in synapse_df.columns],
    'Max': [synapse_df[col].max() for col in synapse_df.columns],
    'Range': [synapse_df[col].max() - synapse_df[col].min() for col in synapse_df.columns]
}, index=synapse_df.columns)

axes[1,1].table(cellText=coord_ranges.round(1).values,
                rowLabels=coord_ranges.index,
                colLabels=coord_ranges.columns,
                cellLoc='center',
                loc='center')
axes[1,1].set_title('Coordinate Ranges (nm)')
axes[1,1].axis('off')

plt.tight_layout()
plt.show()

print("✅ Synapse data visualization complete")

## Step 2: Extract Pre and Post-Synapse Coordinates

Now we'll separate the coordinates and visualize the 3D spatial distribution of synapses. This helps us understand:
- Spatial density of synapses
- Clustering patterns
- Volume coverage
- Pre/post relationship geometry
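
One possible way to put a number on spatial density, sketched under the assumption that coordinates are in nanometres and ordered (z, y, x), is to count synapses in a coarse grid of boxes. The random positions and the 2 x 4 x 4 grid are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical synapse positions in nm, ordered (z, y, x)
demo_coords = rng.uniform(low=(100, 1000, 1000), high=(900, 9000, 9000), size=(200, 3))

# Count synapses in a coarse 2 x 4 x 4 grid of boxes covering the data
counts, edges = np.histogramdd(demo_coords, bins=(2, 4, 4))

# Convert counts to density (synapses per cubic micrometre)
box_nm = [e[1] - e[0] for e in edges]
box_volume_um3 = np.prod(box_nm) / 1e9   # 1 um^3 = 1e9 nm^3
density = counts / box_volume_um3

print(f"Boxes: {counts.shape}, synapses counted: {int(counts.sum())}")
print(f"Density range: {density.min():.2f} to {density.max():.2f} per um^3")
```

Strong variation between boxes would hint at clustering, while near-uniform counts suggest even coverage of the volume.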

In [None]:
# Extract coordinate arrays (convert to Synful internal format: z,y,x)
pre_coords = synapse_df[['pre_z', 'pre_y', 'pre_x']].values  # z,y,x order
post_coords = synapse_df[['post_z', 'post_y', 'post_x']].values

print(f"📐 Pre-synapse coordinates shape: {pre_coords.shape}")
print(f"📐 Post-synapse coordinates shape: {post_coords.shape}")

# 3D visualization of synapse locations
fig = plt.figure(figsize=(16, 6))

# Plot 1: Pre-synapses
ax1 = fig.add_subplot(131, projection='3d')
scatter1 = ax1.scatter(pre_coords[:, 2], pre_coords[:, 1], pre_coords[:, 0], 
                       c='red', alpha=0.6, s=20, label='Pre-synapse')
ax1.set_xlabel('X (nm)')
ax1.set_ylabel('Y (nm)')
ax1.set_zlabel('Z (nm)')
ax1.set_title('Pre-synapse Locations')

# Plot 2: Post-synapses  
ax2 = fig.add_subplot(132, projection='3d')
scatter2 = ax2.scatter(post_coords[:, 2], post_coords[:, 1], post_coords[:, 0],
                       c='blue', alpha=0.6, s=20, label='Post-synapse')
ax2.set_xlabel('X (nm)')
ax2.set_ylabel('Y (nm)')
ax2.set_zlabel('Z (nm)')
ax2.set_title('Post-synapse Locations')

# Plot 3: Both together with connections
ax3 = fig.add_subplot(133, projection='3d')
ax3.scatter(pre_coords[:, 2], pre_coords[:, 1], pre_coords[:, 0],
            c='red', alpha=0.6, s=20, label='Pre')
ax3.scatter(post_coords[:, 2], post_coords[:, 1], post_coords[:, 0],
            c='blue', alpha=0.6, s=20, label='Post')

# Draw connections for first 50 synapses (to avoid clutter)
for i in range(min(50, len(pre_coords))):
    ax3.plot([pre_coords[i, 2], post_coords[i, 2]],
             [pre_coords[i, 1], post_coords[i, 1]],
             [pre_coords[i, 0], post_coords[i, 0]],
             'k-', alpha=0.3, linewidth=0.5)

ax3.set_xlabel('X (nm)')
ax3.set_ylabel('Y (nm)')
ax3.set_zlabel('Z (nm)')
ax3.set_title('Pre-Post Connections')
ax3.legend()

plt.tight_layout()
plt.show()

print(f"✅ Extracted coordinates for {len(pre_coords)} synapses")
print(f"📊 Coordinate format: Z,Y,X (Synful internal format)")
print(f"🔗 Visualized spatial distribution and connections")

## Step 3: Calculate Synapse Vector Features

This step computes important geometric features used in training:
- **Direction vectors**: Pre → Post displacement  
- **Distances**: Euclidean distance between partners
- **Orientations**: 3D direction angles
- **Spatial statistics**: For understanding synapse geometry
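
A small worked example may make these definitions concrete. It uses a single hypothetical synapse whose post-synaptic site lies 10 nm above and 10 nm along X from the pre-synaptic site, with coordinates in (z, y, x) order as used throughout this notebook.

```python
import numpy as np

# Hypothetical synapse in (z, y, x) order, units of nm
pre = np.array([0.0, 0.0, 0.0])
post = np.array([10.0, 0.0, 10.0])

vec = post - pre                       # pre -> post displacement
dist = np.linalg.norm(vec)             # Euclidean distance
# Elevation: angle above the XY plane
elev = np.degrees(np.arctan2(vec[0], np.hypot(vec[1], vec[2])))
# Azimuth: angle within the XY plane, measured from +X
azim = np.degrees(np.arctan2(vec[1], vec[2]))

print(f"distance={dist:.3f} nm, elevation={elev:.1f} deg, azimuth={azim:.1f} deg")
# prints: distance=14.142 nm, elevation=45.0 deg, azimuth=0.0 deg
```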

In [None]:
# Calculate direction vectors (pre -> post)
direction_vectors = post_coords - pre_coords

# Calculate distances
distances = np.linalg.norm(direction_vectors, axis=1)

# Calculate angles (elevation and azimuth)
# Elevation: angle from XY plane
elevation_angles = np.arctan2(direction_vectors[:, 0], 
                             np.sqrt(direction_vectors[:, 1]**2 + direction_vectors[:, 2]**2))
# Azimuth: angle in XY plane
azimuth_angles = np.arctan2(direction_vectors[:, 1], direction_vectors[:, 2])

# Convert to degrees
elevation_deg = np.degrees(elevation_angles)
azimuth_deg = np.degrees(azimuth_angles)

print(f"📊 Vector Features Summary:")
print(f"   Distance range: {distances.min():.1f} - {distances.max():.1f} nm")
print(f"   Mean distance: {distances.mean():.1f} ± {distances.std():.1f} nm")
print(f"   Direction vector shape: {direction_vectors.shape}")

# Visualize vector features
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Distance distribution
axes[0,0].hist(distances, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
axes[0,0].set_xlabel('Distance (nm)')
axes[0,0].set_ylabel('Count')
axes[0,0].set_title('Synapse Partner Distances')
axes[0,0].grid(True, alpha=0.3)

# Direction vector components
for i, axis_name in enumerate(['Z', 'Y', 'X']):
    axes[0,1].hist(direction_vectors[:, i], bins=30, alpha=0.5, 
                   label=f'{axis_name} component')
axes[0,1].set_xlabel('Vector component (nm)')
axes[0,1].set_ylabel('Count')
axes[0,1].set_title('Direction Vector Components')
axes[0,1].legend()
axes[0,1].grid(True, alpha=0.3)

# Angle distributions
axes[0,2].hist(elevation_deg, bins=30, alpha=0.7, color='orange')
axes[0,2].set_xlabel('Elevation angle (degrees)')
axes[0,2].set_ylabel('Count')
axes[0,2].set_title('Elevation Angles (from XY plane)')
axes[0,2].grid(True, alpha=0.3)

# Distance vs vector magnitude
axes[1,0].scatter(distances, np.linalg.norm(direction_vectors, axis=1), alpha=0.6)
axes[1,0].plot([0, distances.max()], [0, distances.max()], 'r--', alpha=0.5)
axes[1,0].set_xlabel('Calculated Distance (nm)')
axes[1,0].set_ylabel('Vector Magnitude (nm)')
axes[1,0].set_title('Distance Validation')
axes[1,0].grid(True, alpha=0.3)

# Azimuth distribution (polar plot); remove the Cartesian axes occupying this slot first
axes[1,1].remove()
ax_polar = fig.add_subplot(2, 3, 5, projection='polar')
ax_polar.hist(azimuth_angles, bins=20, alpha=0.7)
ax_polar.set_title('Azimuth Angle Distribution')

# 3D direction vectors (quiver plot sample); remove the 2D axes occupying this slot first
axes[1,2].remove()
ax_3d = fig.add_subplot(236, projection='3d')
# Show subset to avoid clutter
subset_idx = np.random.choice(len(pre_coords), size=20, replace=False)
# Arrows are normalized and drawn 500 nm long so they stay visible at the
# scale of the whole volume; they show direction only, not true distance
ax_3d.quiver(pre_coords[subset_idx, 2], pre_coords[subset_idx, 1], pre_coords[subset_idx, 0],
             direction_vectors[subset_idx, 2], direction_vectors[subset_idx, 1], direction_vectors[subset_idx, 0],
             length=500, normalize=True, alpha=0.7)
ax_3d.set_xlabel('X (nm)')
ax_3d.set_ylabel('Y (nm)')
ax_3d.set_zlabel('Z (nm)')
ax_3d.set_title('Direction Vectors (sample)')

plt.tight_layout()
plt.show()

print("✅ Vector features calculated and visualized")

## Step 4: Simulate Volume Data and Training Cube Extraction

Now we simulate loading from a large Zarr volume and extracting training cubes around synapse locations. This demonstrates:
- Volume data structure and properties
- Spatial cube extraction
- Training data organization
- Memory-efficient sampling
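
The coordinate bookkeeping behind cube extraction can be sketched as follows: convert a physical centre in nm to voxel indices, then derive start and stop indices that stay inside the volume. The helper `cube_bounds`, the shift-instead-of-crop border policy, and the example volume shape are assumptions made for illustration, not Synful's actual behaviour.

```python
import numpy as np

def cube_bounds(center_nm, cube_size, voxel_size, volume_shape):
    """Return (start, stop) voxel indices of a cube centred on center_nm.

    All arguments are (z, y, x). The cube is shifted, not cropped, when it
    would cross a volume border, so it always has the requested size.
    """
    center_vox = np.floor(np.asarray(center_nm) / np.asarray(voxel_size)).astype(int)
    size = np.asarray(cube_size)
    shape = np.asarray(volume_shape)
    if np.any(size > shape):
        raise ValueError("cube is larger than the volume")
    start = center_vox - size // 2
    start = np.clip(start, 0, shape - size)   # keep the cube inside the volume
    return start, start + size

# Hypothetical volume of 250 x 2500 x 2500 voxels at 40 x 4 x 4 nm
start, stop = cube_bounds(
    center_nm=(500.0, 5000.0, 5000.0),
    cube_size=(42, 430, 430),
    voxel_size=(40, 4, 4),
    volume_shape=(250, 2500, 2500),
)
print("start:", start, "stop:", stop)
roi = tuple(slice(a, b) for a, b in zip(start, stop))  # usable as volume[roi]
```

With a chunked array such as a Zarr dataset, indexing with `roi` would read only the chunks that overlap the cube, which is what keeps sampling memory-efficient.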

In [None]:
# Simulate volume data and cube extraction
def simulate_em_volume_cube(center_coords, cube_size=(42, 430, 430), voxel_size=(40, 4, 4)):
    """Simulate extracting a training cube from EM volume"""
    
    # Convert physical coordinates to voxel coordinates
    voxel_coords = center_coords / np.array(voxel_size)
    voxel_coords = voxel_coords.astype(int)
    
    # Create synthetic EM-like data
    cube = np.random.randn(*cube_size).astype(np.float32)
    
    # Add some structure to make it look more EM-like
    # Add some membrane-like features
    for _ in range(5):
        membrane_z = np.random.randint(5, cube_size[0]-5)
        membrane_thickness = np.random.randint(1, 3)
        cube[membrane_z:membrane_z+membrane_thickness, :, :] += np.random.uniform(1, 2)
    
    # Add noise and normalize
    cube += np.random.normal(0, 0.1, cube.shape)
    cube = (cube - cube.mean()) / cube.std()
    
    return cube, voxel_coords

# Configuration matching training parameters
cube_size = (42, 430, 430)  # z, y, x in voxels
voxel_size = (40, 4, 4)     # z, y, x in nm
physical_cube_size = np.array(cube_size) * np.array(voxel_size)

print(f"📐 Training cube configuration:")
print(f"   Cube size (voxels): {cube_size}")
print(f"   Voxel size (nm): {voxel_size}")
print(f"   Physical size (nm): {physical_cube_size}")
print(f"   Total volume: {np.prod(cube_size):,} voxels")

# Extract training cubes for first few synapses
n_samples = 3
sample_cubes = []
sample_coords = []

for i in range(n_samples):
    # Use middle point between pre and post as cube center
    center_coord = (pre_coords[i] + post_coords[i]) / 2
    
    cube, voxel_coord = simulate_em_volume_cube(center_coord, cube_size, voxel_size)
    sample_cubes.append(cube)
    sample_coords.append(voxel_coord)
    
    print(f"Synapse {i+1}: Center at {center_coord} nm -> {voxel_coord} voxels")

# Visualize extracted cubes
fig, axes = plt.subplots(n_samples, 3, figsize=(15, n_samples*4))
if n_samples == 1:
    axes = axes.reshape(1, -1)

for i in range(n_samples):
    cube = sample_cubes[i]
    
    # Show different slices through the cube
    mid_z = cube.shape[0] // 2
    mid_y = cube.shape[1] // 2
    mid_x = cube.shape[2] // 2
    
    # XY slice (middle Z)
    im1 = axes[i, 0].imshow(cube[mid_z, :, :], cmap='gray', aspect='auto')
    axes[i, 0].set_title(f'Synapse {i+1}: XY slice (Z={mid_z})')
    axes[i, 0].set_xlabel('X (voxels)')
    axes[i, 0].set_ylabel('Y (voxels)')
    
    # XZ slice (middle Y)
    im2 = axes[i, 1].imshow(cube[:, mid_y, :], cmap='gray', aspect='auto')
    axes[i, 1].set_title(f'XZ slice (Y={mid_y})')
    axes[i, 1].set_xlabel('X (voxels)')
    axes[i, 1].set_ylabel('Z (voxels)')
    
    # YZ slice (middle X)
    im3 = axes[i, 2].imshow(cube[:, :, mid_x], cmap='gray', aspect='auto')
    axes[i, 2].set_title(f'YZ slice (X={mid_x})')
    axes[i, 2].set_xlabel('Y (voxels)')
    axes[i, 2].set_ylabel('Z (voxels)')

plt.tight_layout()
plt.show()

print(f"✅ Extracted {n_samples} training cubes")
print(f"📊 Cube statistics:")
for i, cube in enumerate(sample_cubes):
    print(f"   Cube {i+1}: {cube.shape}, range: {cube.min():.3f} to {cube.max():.3f}")

## Step 5: Generate Training Masks and Direction Vectors

This is the core of Synful training data preparation:
- **Binary masks**: Mark synapse locations in the volume
- **Direction vectors**: Encode pre→post direction at each voxel
- **Multitask outputs**: Both detection (mask) and direction prediction
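
To get a feel for how small the binary target is, here is a minimal sketch of a physical-distance blob on a tiny grid. It uses `np.ogrid` so that only three thin index arrays are allocated before broadcasting. The helper `blob_mask` and the toy grid size are illustrative assumptions; the radius and voxel size mirror the defaults used in this notebook.

```python
import numpy as np

def blob_mask(shape, center_vox, radius_nm, voxel_size):
    """Binary ball of radius_nm around center_vox; all tuples are (z, y, x)."""
    zz, yy, xx = np.ogrid[0:shape[0], 0:shape[1], 0:shape[2]]
    # Squared physical distance; the open grids broadcast to the full shape
    dist2 = (((zz - center_vox[0]) * voxel_size[0]) ** 2
             + ((yy - center_vox[1]) * voxel_size[1]) ** 2
             + ((xx - center_vox[2]) * voxel_size[2]) ** 2)
    return (dist2 <= radius_nm ** 2).astype(np.float32)

demo_mask = blob_mask(shape=(5, 11, 11), center_vox=(2, 5, 5),
                      radius_nm=10, voxel_size=(40, 4, 4))
print("positive voxels:", int(demo_mask.sum()))
print("z-slices containing the blob:", np.flatnonzero(demo_mask.any(axis=(1, 2))))
```

With a 10 nm radius and 40 nm z-spacing, the blob covers 21 voxels, all in the central z-slice, so positive voxels are a tiny fraction of a 42 x 430 x 430 cube.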

In [None]:
# Generate training masks and direction vectors
def create_synapse_mask_and_vectors(cube_shape, synapse_coords, direction_vector, 
                                   blob_radius=10, d_blob_radius=100, voxel_size=(40, 4, 4)):
    """Create training mask and direction vectors for a cube"""
    
    # Initialize outputs
    mask = np.zeros(cube_shape, dtype=np.float32)
    direction_map = np.zeros((3,) + cube_shape, dtype=np.float32)  # 3D vectors
    
    # Convert synapse coordinates to cube coordinates
    cube_center = np.array(cube_shape) // 2
    
    # For this demo, place synapse at cube center
    # In real training, this would be calculated from actual coordinates
    synapse_voxel = cube_center.astype(int)
    
    # Create spherical mask around synapse location
    z_indices, y_indices, x_indices = np.mgrid[0:cube_shape[0], 0:cube_shape[1], 0:cube_shape[2]]
    
    # Calculate distance from synapse center
    dz = (z_indices - synapse_voxel[0]) * voxel_size[0]
    dy = (y_indices - synapse_voxel[1]) * voxel_size[1] 
    dx = (x_indices - synapse_voxel[2]) * voxel_size[2]
    
    distance = np.sqrt(dz**2 + dy**2 + dx**2)
    
    # Create binary mask within blob_radius
    mask[distance <= blob_radius] = 1.0
    
    # Create direction vectors within d_blob_radius
    direction_mask = distance <= d_blob_radius
    
    # Normalize direction vector
    norm_direction = direction_vector / (np.linalg.norm(direction_vector) + 1e-8)
    
    # Set direction vectors where mask is active
    for i in range(3):
        direction_map[i][direction_mask] = norm_direction[i]
    
    return mask, direction_map

# Generate masks and direction vectors for sample cubes
sample_masks = []
sample_directions = []

for i in range(n_samples):
    # Get synapse direction vector
    direction = direction_vectors[i]
    
    # Create mask and direction map
    mask, direction_map = create_synapse_mask_and_vectors(
        cube_size, sample_coords[i], direction
    )
    
    sample_masks.append(mask)
    sample_directions.append(direction_map)
    
    print(f"Synapse {i+1}: Mask has {mask.sum():.0f} positive voxels, "
          f"direction magnitude: {np.linalg.norm(direction):.2f}")

# Visualize masks and direction vectors
fig, axes = plt.subplots(n_samples, 4, figsize=(20, n_samples*4))
if n_samples == 1:
    axes = axes.reshape(1, -1)

for i in range(n_samples):
    cube = sample_cubes[i]
    mask = sample_masks[i]
    direction_map = sample_directions[i]
    
    mid_z = cube.shape[0] // 2
    
    # Original volume slice
    axes[i, 0].imshow(cube[mid_z, :, :], cmap='gray', alpha=0.8)
    axes[i, 0].set_title(f'Synapse {i+1}: Raw Volume')
    axes[i, 0].set_xlabel('X (voxels)')
    axes[i, 0].set_ylabel('Y (voxels)')
    
    # Mask overlay
    axes[i, 1].imshow(cube[mid_z, :, :], cmap='gray', alpha=0.7)
    mask_slice = mask[mid_z, :, :]
    axes[i, 1].contour(mask_slice, levels=[0.5], colors='red', linewidths=2)
    axes[i, 1].set_title('Synapse Mask (red contour)')
    axes[i, 1].set_xlabel('X (voxels)')
    axes[i, 1].set_ylabel('Y (voxels)')
    
    # Direction vector magnitude
    direction_magnitude = np.linalg.norm(direction_map, axis=0)
    im3 = axes[i, 2].imshow(direction_magnitude[mid_z, :, :], cmap='viridis')
    axes[i, 2].set_title('Direction Vector Magnitude')
    axes[i, 2].set_xlabel('X (voxels)')
    axes[i, 2].set_ylabel('Y (voxels)')
    plt.colorbar(im3, ax=axes[i, 2], fraction=0.046)
    
    # Direction vector field (subsample for visibility)
    step = 20
    y_sub = slice(None, None, step)
    x_sub = slice(None, None, step)
    
    Y, X = np.meshgrid(np.arange(0, cube_size[1], step), 
                       np.arange(0, cube_size[2], step), indexing='ij')
    
    dy_field = direction_map[1, mid_z, y_sub, x_sub]
    dx_field = direction_map[2, mid_z, y_sub, x_sub]
    
    axes[i, 3].imshow(cube[mid_z, :, :], cmap='gray', alpha=0.5)
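    # angles='xy' keeps arrows consistent with imshow's inverted Y axis;
    # scale=0.05 in 'xy' units draws a unit vector as a 20-voxel arrow
    axes[i, 3].quiver(X, Y, dx_field, dy_field, angles='xy',
                     scale=0.05, scale_units='xy', alpha=0.8, color='red')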
    axes[i, 3].set_title('Direction Vector Field')
    axes[i, 3].set_xlabel('X (voxels)')
    axes[i, 3].set_ylabel('Y (voxels)')

plt.tight_layout()
plt.show()

# Statistics
print("\n📊 Training Target Statistics:")
for i in range(n_samples):
    mask = sample_masks[i]
    direction_map = sample_directions[i]
    
    mask_ratio = mask.sum() / mask.size * 100
    direction_magnitude = np.linalg.norm(direction_map, axis=0)
    direction_active = (direction_magnitude > 0).sum()
    
    print(f"   Synapse {i+1}:")
    print(f"     Mask coverage: {mask_ratio:.3f}% of volume")
    print(f"     Direction active voxels: {direction_active:,}")
    print(f"     Mean direction magnitude: {direction_magnitude.mean():.3f}")

print("\n✅ Training masks and direction vectors generated")

## Step 6: Apply Data Augmentations

Data augmentation is crucial for training robust models. Synful applies 3D augmentations in three groups, which this demo simulates in simplified form:
- **Geometric**: Rotations, flips, elastic deformation
- **Intensity**: Scaling, shifting, gamma correction  
- **Noise**: Gaussian noise, salt-and-pepper artifacts
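
Geometric augmentations need extra care with vector targets: moving the voxels is not enough, because the vector components must be transformed too. The sketch below shows this for a 90 degree rotation in the XY plane. The helper `rotate_xy_90` and the toy arrays are hypothetical and are not taken from Synful's augmentation code.

```python
import numpy as np

def rotate_xy_90(volume, direction_map):
    """Rotate a (z, y, x) volume and its (3, z, y, x) vector field by 90 degrees in XY."""
    rot_volume = np.rot90(volume, k=1, axes=(1, 2)).copy()
    moved = np.rot90(direction_map, k=1, axes=(2, 3))
    # np.rot90 sends a voxel at (y, x) to (N_x - 1 - x, y), so a displacement
    # (dy, dx) becomes (-dx, dy); the z component is unchanged.
    rot_dirs = np.stack([moved[0], -moved[2], moved[1]])
    return rot_volume, rot_dirs

# Toy check: two markers separated by +2 voxels along x, field pointing along +x
vol = np.zeros((1, 5, 5), dtype=np.float32)
vol[0, 1, 1] = 1.0   # tail
vol[0, 1, 3] = 2.0   # head
field = np.zeros((3, 1, 5, 5), dtype=np.float32)
field[2] = 1.0       # unit vectors along +x

rot_vol, rot_field = rotate_xy_90(vol, field)
tail = np.argwhere(rot_vol == 1.0)[0]
head = np.argwhere(rot_vol == 2.0)[0]
print("marker displacement (z, y, x):", head - tail)
print("rotated vector (z, y, x):", rot_field[:, 0, 0, 0])
```

After rotation the markers are separated along negative Y, and the stored vectors point along negative Y as well, so image content and direction target stay consistent.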

Let's see how these affect our training data:

In [None]:
# Simulate the augmentation pipeline
def apply_intensity_augmentation(volume, scale_range=(0.8, 1.2), shift_range=(-0.2, 0.2)):
    """Apply intensity augmentations"""
    augmented = volume.copy()
    
    # Random intensity scaling
    scale = np.random.uniform(*scale_range)
    augmented = augmented * scale
    
    # Random intensity shift
    shift = np.random.uniform(*shift_range)
    augmented = augmented + shift
    
    # Random gamma correction (30% chance)
    if np.random.random() < 0.3:
        gamma = np.random.uniform(0.8, 1.2)
        augmented = np.sign(augmented) * np.power(np.abs(augmented), gamma)
    
    return augmented

def apply_noise_augmentation(volume, noise_std=0.1):
    """Apply noise augmentations"""
    augmented = volume.copy()
    
    # Gaussian noise
    noise = np.random.normal(0, noise_std, volume.shape)
    augmented = augmented + noise
    
    # Salt and pepper noise (10% chance)
    if np.random.random() < 0.1:
        mask = np.random.random(volume.shape) < 0.01
        augmented[mask] = np.random.uniform(-1, 1, mask.sum())
    
    return augmented

def apply_geometric_augmentation(volume, mask, direction_map):
    """Apply geometric augmentations (simplified for demo)"""
    # For demonstration, we'll just flip along one axis
    if np.random.random() < 0.5:
        # Flip Y axis; copy so the results are contiguous arrays rather than
        # negative-stride views of the inputs
        volume = np.flip(volume, axis=1).copy()
        mask = np.flip(mask, axis=1).copy()
        direction_map = np.flip(direction_map, axis=2).copy()  # axis 2 corresponds to Y in direction map
        direction_map[1] *= -1  # Flip Y component of direction vectors
    
    return volume, mask, direction_map

# Apply augmentations to first sample
original_cube = sample_cubes[0].copy()
original_mask = sample_masks[0].copy()
original_direction = sample_directions[0].copy()

# Create multiple augmented versions
n_augmentations = 4
augmented_data = []

np.random.seed(123)  # For reproducible demo
for aug_idx in range(n_augmentations):
    # Start with original
    aug_cube = original_cube.copy()
    aug_mask = original_mask.copy()
    aug_direction = original_direction.copy()
    
    # Apply different augmentation combinations
    if aug_idx == 0:
        # Original (no augmentation)
        pass
    elif aug_idx == 1:
        # Intensity only
        aug_cube = apply_intensity_augmentation(aug_cube)
    elif aug_idx == 2:
        # Noise only
        aug_cube = apply_noise_augmentation(aug_cube)
    elif aug_idx == 3:
        # Geometric + intensity + noise
        aug_cube, aug_mask, aug_direction = apply_geometric_augmentation(aug_cube, aug_mask, aug_direction)
        aug_cube = apply_intensity_augmentation(aug_cube)
        aug_cube = apply_noise_augmentation(aug_cube)
    
    augmented_data.append((aug_cube, aug_mask, aug_direction))

# Visualize augmentations
fig, axes = plt.subplots(n_augmentations, 4, figsize=(20, n_augmentations*4))
aug_names = ['Original', 'Intensity Aug', 'Noise Aug', 'Combined Aug']

for aug_idx in range(n_augmentations):
    cube, mask, direction_map = augmented_data[aug_idx]
    mid_z = cube.shape[0] // 2
    
    # Raw volume
    im1 = axes[aug_idx, 0].imshow(cube[mid_z, :, :], cmap='gray', vmin=-2, vmax=2)
    axes[aug_idx, 0].set_title(f'{aug_names[aug_idx]}: Raw Volume')
    axes[aug_idx, 0].set_ylabel('Y (voxels)')
    
    # Volume with mask overlay
    axes[aug_idx, 1].imshow(cube[mid_z, :, :], cmap='gray', alpha=0.7, vmin=-2, vmax=2)
    mask_slice = mask[mid_z, :, :]
    axes[aug_idx, 1].contour(mask_slice, levels=[0.5], colors='red', linewidths=2)
    axes[aug_idx, 1].set_title('Volume + Mask')
    
    # Direction magnitude
    direction_magnitude = np.linalg.norm(direction_map, axis=0)
    im3 = axes[aug_idx, 2].imshow(direction_magnitude[mid_z, :, :], cmap='viridis')
    axes[aug_idx, 2].set_title('Direction Magnitude')
    
    # Intensity histogram
    axes[aug_idx, 3].hist(cube.flatten(), bins=50, alpha=0.7, density=True)
    axes[aug_idx, 3].set_xlabel('Intensity')
    axes[aug_idx, 3].set_ylabel('Density')
    axes[aug_idx, 3].set_title('Intensity Distribution')
    axes[aug_idx, 3].grid(True, alpha=0.3)
    
    # Add statistics text
    stats_text = f'Mean: {cube.mean():.2f}\nStd: {cube.std():.2f}\nRange: [{cube.min():.2f}, {cube.max():.2f}]'
    axes[aug_idx, 3].text(0.05, 0.95, stats_text, transform=axes[aug_idx, 3].transAxes, 
                         verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

# Add x-labels to bottom row
for col in range(4):
    axes[-1, col].set_xlabel('X (voxels)' if col < 3 else 'Intensity')

plt.tight_layout()
plt.show()

# Quantify augmentation effects
print("📊 Augmentation Effects Summary:")
for aug_idx, (cube, mask, direction_map) in enumerate(augmented_data):
    print(f"   {aug_names[aug_idx]}:")
    print(f"     Volume: mean={cube.mean():.3f}, std={cube.std():.3f}")
    print(f"     Mask sum: {mask.sum():.0f} voxels")
    print(f"     Direction active: {(np.linalg.norm(direction_map, axis=0) > 0).sum():,} voxels")

print("\n✅ Data augmentations demonstrated")

## Step 7: Final Training Batch Visualization

Now let's see what the final training batch looks like after all preprocessing steps. This demonstrates:
- **Batch structure**: How data is organized for PyTorch training
- **Tensor shapes**: Input and target dimensions
- **Data ranges**: Normalized values ready for neural network
- **Quality checks**: Ensuring data integrity
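
Before stacking cubes into a batch, it can help to estimate how much memory the tensors will need. The sketch below assumes float32 arrays with one raw channel, one mask channel, and three direction channels, which matches the shapes used in this notebook. The helper `batch_memory_mib` is hypothetical.

```python
import numpy as np

def batch_memory_mib(batch_size, cube_size, dtype=np.float32):
    """Estimate MiB for raw (1 ch), mask (1 ch) and direction (3 ch) tensors."""
    voxels = int(np.prod(cube_size))
    bytes_per_channel = batch_size * voxels * np.dtype(dtype).itemsize
    channels = {"raw": 1, "mask": 1, "direction": 3}
    sizes = {name: c * bytes_per_channel / 1024**2 for name, c in channels.items()}
    sizes["total"] = sum(sizes.values())
    return sizes

for name, mib in batch_memory_mib(batch_size=4, cube_size=(42, 430, 430)).items():
    print(f"{name:>9}: {mib:7.1f} MiB")
```

This covers only the input and target arrays. Network activations and gradients during training come on top and usually dominate GPU memory.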

In [None]:
# Create final training batch
def create_training_batch(augmented_data, batch_size=None):
    """Create a training batch from augmented data"""
    if batch_size is None:
        batch_size = len(augmented_data)
    
    # Initialize batch tensors
    batch_raw = []
    batch_mask = []
    batch_direction = []
    
    for i in range(min(batch_size, len(augmented_data))):
        cube, mask, direction_map = augmented_data[i]
        
        # Add channel dimension and normalize
        raw_normalized = (cube - cube.mean()) / (cube.std() + 1e-8)
        raw_with_channel = raw_normalized[np.newaxis, ...]  # Add channel dim
        
        # Prepare mask (add channel dim)
        mask_with_channel = mask[np.newaxis, ...]
        
        # Direction map is already (3, z, y, x)
        
        batch_raw.append(raw_with_channel)
        batch_mask.append(mask_with_channel)
        batch_direction.append(direction_map)
    
    # Stack into batch tensors
    batch_raw = np.stack(batch_raw, axis=0)  # (B, 1, Z, Y, X)
    batch_mask = np.stack(batch_mask, axis=0)  # (B, 1, Z, Y, X)
    batch_direction = np.stack(batch_direction, axis=0)  # (B, 3, Z, Y, X)
    
    return batch_raw, batch_mask, batch_direction

# Create training batch
batch_raw, batch_mask, batch_direction = create_training_batch(augmented_data)

print("🚀 Final Training Batch:")
print(f"   Raw data: {batch_raw.shape} (B, C, Z, Y, X)")
print(f"   Masks: {batch_mask.shape} (B, C, Z, Y, X)")  
print(f"   Directions: {batch_direction.shape} (B, 3, Z, Y, X)")
print(f"   Raw data range: [{batch_raw.min():.3f}, {batch_raw.max():.3f}]")
print(f"   Mask range: [{batch_mask.min():.1f}, {batch_mask.max():.1f}]")
print(f"   Direction range: [{batch_direction.min():.3f}, {batch_direction.max():.3f}]")

# Visualize final batch
fig, axes = plt.subplots(3, batch_raw.shape[0], figsize=(4*batch_raw.shape[0], 12))
if batch_raw.shape[0] == 1:
    axes = axes.reshape(-1, 1)

for b in range(batch_raw.shape[0]):
    mid_z = batch_raw.shape[2] // 2
    
    # Raw data
    im1 = axes[0, b].imshow(batch_raw[b, 0, mid_z, :, :], cmap='gray')
    axes[0, b].set_title(f'Batch {b}: Raw Input')
    axes[0, b].set_xlabel('X')
    axes[0, b].set_ylabel('Y')
    plt.colorbar(im1, ax=axes[0, b], fraction=0.046)
    
    # Mask target
    im2 = axes[1, b].imshow(batch_mask[b, 0, mid_z, :, :], cmap='Reds', vmin=0, vmax=1)
    axes[1, b].set_title(f'Mask Target')
    axes[1, b].set_xlabel('X')
    axes[1, b].set_ylabel('Y')
    plt.colorbar(im2, ax=axes[1, b], fraction=0.046)
    
    # Direction magnitude target
    direction_mag = np.linalg.norm(batch_direction[b], axis=0)
    im3 = axes[2, b].imshow(direction_mag[mid_z, :, :], cmap='viridis')
    axes[2, b].set_title(f'Direction Target')
    axes[2, b].set_xlabel('X')
    axes[2, b].set_ylabel('Y')
    plt.colorbar(im3, ax=axes[2, b], fraction=0.046)

plt.tight_layout()
plt.show()

# Data quality checks
print("\n🔍 Data Quality Checks:")

# Check for NaN or inf values
has_nan_raw = np.isnan(batch_raw).any()
has_inf_raw = np.isinf(batch_raw).any()
has_nan_mask = np.isnan(batch_mask).any()
has_nan_direction = np.isnan(batch_direction).any()

print(f"   Raw data: NaN={has_nan_raw}, Inf={has_inf_raw}")
print(f"   Masks: NaN={has_nan_mask}")
print(f"   Directions: NaN={has_nan_direction}")

# Check value ranges
print(f"   Raw normalization: mean={batch_raw.mean():.3f}, std={batch_raw.std():.3f}")
print(f"   Mask values: unique={np.unique(batch_mask)}")

# Check spatial consistency
for b in range(batch_raw.shape[0]):
    mask_volume = batch_mask[b, 0].sum()
    direction_active = (np.linalg.norm(batch_direction[b], axis=0) > 0).sum()
    print(f"   Batch {b}: mask_volume={mask_volume:.0f}, direction_active={direction_active}")

print("\n✅ Training batch ready for neural network!")

# Memory usage estimation
raw_memory_mb = batch_raw.nbytes / (1024**2)
mask_memory_mb = batch_mask.nbytes / (1024**2)
direction_memory_mb = batch_direction.nbytes / (1024**2)
total_memory_mb = raw_memory_mb + mask_memory_mb + direction_memory_mb

print(f"\n💾 Memory Usage:")
print(f"   Raw data: {raw_memory_mb:.1f} MB")
print(f"   Masks: {mask_memory_mb:.1f} MB")
print(f"   Directions: {direction_memory_mb:.1f} MB")
print(f"   Total per batch: {total_memory_mb:.1f} MB")

## Summary: Complete Training Pipeline Visualization

🎉 **Congratulations!** You've now seen every step of the Synful training data pipeline:

### Pipeline Steps Completed:
1. ✅ **TSV Data Loading**: Parsed pre/post synapse coordinates
2. ✅ **Spatial Analysis**: Analyzed 3D synapse distributions
3. ✅ **Vector Features**: Calculated direction vectors and distances  
4. ✅ **Volume Simulation**: Simulated Zarr cube extraction with synthetic EM-like data
5. ✅ **Mask Generation**: Created binary detection targets
6. ✅ **Direction Mapping**: Generated 3D direction vector targets
7. ✅ **Data Augmentation**: Applied geometric, intensity, and noise augmentations
8. ✅ **Final Batching**: Prepared training-ready tensors

### Key Insights:
- 📊 **Data Structure**: TSV → coordinates → training cubes → neural network tensors
- 🎯 **Multitask Learning**: Both synapse detection (masks) and direction prediction (vectors)
- 🔄 **Augmentation**: Robust training through diverse data variations
- 💾 **Memory Footprint**: Per-batch memory estimated from tensor shapes and dtypes
- 🔍 **Quality Control**: NaN/Inf, value-range, and mask-volume checks on the final batch

### Next Steps:
- Use this pipeline with real zarr volumes and TSV/MongoDB synapse data
- Adjust parameters (cube_size, blob_radius, augmentation strength) as needed
- Monitor training performance and data quality
- Scale to production with larger datasets

This notebook provides the foundation for understanding and debugging the complete Synful training data flow!