In [None]:
pip install tensorflow numpy matplotlib

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
import os

# ==================== Configuration ====================
CONFIG = dict(
    num_samples=500,       # number of synthetic samples to generate
    image_size=(64, 64),   # image dimensions; data generation uses only the first entry
    base_filters=32,       # convolution filters in the first U-Net stage
    depth=3,               # number of U-Net pooling / upsampling stages
    epochs=60,             # maximum epochs (EarlyStopping may end training sooner)
    batch_size=32,         # mini-batch size for model.fit
    learning_rate=0.001,   # Adam optimizer learning rate
)

# ==================== Data Generation ====================
def create_ring_mask(size=64, inner_radius=0.3, outer_radius=0.5):
    """Generate one ring image: ring=0.9 (dark), background=0.05-0.15 (light).

    Parameters:
    - size: side length, in pixels, of the square output image
    - inner_radius, outer_radius: ring bounds in grid units, where the grid
      spans [-1, 1] along each axis

    Returns:
    - float array of shape (size, size). Pixels whose distance from the centre
      lies in (inner_radius, outer_radius] are 0.9; all other pixels are
      independent draws from U(0.05, 0.15) using NumPy's global random state.
    """
    # A complex step in ogrid means "this many points", inclusive of both ends.
    y, x = np.ogrid[-1:1:size*1j, -1:1:size*1j]
    distance = np.sqrt(x**2 + y**2)

    ring = (distance > inner_radius) & (distance <= outer_radius)
    background = np.random.uniform(0.05, 0.15, (size, size))
    return np.where(ring, 0.9, background)


def generate_synthetic_data(num_samples=500, image_size=(64, 64)):
    """Generate a paired ring dataset: noisy inputs and their clean targets.

    Each sample is built from ONE clean ring image (see ``create_ring_mask``):
    the target ``y`` is that clean image and the input ``X`` is the same image
    plus Gaussian noise. Every input pixel therefore has a matching target
    pixel. The random 90-degree rotation used as augmentation is applied
    identically to input and target.

    Parameters:
    - num_samples: number of (input, target) pairs
    - image_size: only ``image_size[0]`` is used; images are square

    Returns:
    - X, y: float arrays of shape (num_samples, size, size), each min-max
      normalized over the whole array to (approximately) [0, 1]
    """
    size = image_size[0]
    if num_samples == 0:
        # min()/max() on an empty array would raise ValueError.
        empty = np.empty((0, size, size))
        return empty, empty.copy()

    # Data augmentation: one rotation per sample, shared by input and target.
    y = np.array([
        np.rot90(create_ring_mask(size), np.random.randint(0, 4))
        for _ in range(num_samples)
    ])

    # Input = clean target + Gaussian noise simulating measurement error.
    X = np.clip(y + np.random.normal(0, 0.03, y.shape), 0, 1)

    # Normalization (dataset-wide min-max; epsilon avoids division by zero)
    X = (X - X.min()) / (X.max() - X.min() + 1e-8)
    y = (y - y.min()) / (y.max() - y.min() + 1e-8)
    return X, y

def postprocess_prediction(pred, threshold=0.5, scale=10):
    """
    Sharpen a prediction map with a logistic (sigmoid) curve centred on `threshold`.

    Parameters:
    - pred: raw prediction, expected in the 0-1 range (scalar or array)
    - threshold: input value that maps to exactly 0.5 (the sigmoid centre)
    - scale: steepness; larger values push outputs harder towards 0 or 1

    Returns:
    - 1 / (1 + exp(-scale * (pred - threshold))), same shape as `pred`, in [0, 1]
    """
    # Values above the threshold move towards 1.0, values below towards 0.0.
    logits = scale * (pred - threshold)
    return 1.0 / (1.0 + np.exp(-logits))

# ==================== Model Building ====================
def build_unet(input_shape=(64, 64, 1), base_filters=32, depth=3, learning_rate=None):
    """Build and compile a lightweight U-Net.

    Encoder: ``depth`` stages, each with two 3x3 ReLU convolutions followed by
    2x2 max-pooling. The output before pooling is kept as a skip connection.
    Bottleneck: two 3x3 ReLU convolutions.
    Decoder: ``depth`` stages, each with 2x upsampling, concatenation with the
    matching skip connection, and two 3x3 ReLU convolutions.
    Stage ``i`` uses ``base_filters * 2**i`` filters. A final 1x1 convolution
    with sigmoid activation produces one output channel in [0, 1].

    Parameters:
    - input_shape: (height, width, channels). Height and width must be
      divisible by 2**depth; None is accepted for a variable dimension.
    - base_filters: filter count of the first encoder stage
    - depth: number of pooling / upsampling stages
    - learning_rate: Adam learning rate; None falls back to CONFIG['learning_rate']

    Returns:
    - a compiled Keras Model (binary cross-entropy loss, MSE and MAE metrics)

    Raises:
    - ValueError: if a known spatial dimension is not divisible by 2**depth
    """
    # Pooling floors odd sizes, so after upsampling the decoder tensor would no
    # longer match its skip connection and Concatenate would fail obscurely.
    stride = 2 ** depth
    for dim in input_shape[:2]:
        if dim is not None and dim % stride != 0:
            raise ValueError(
                f"input_shape spatial dims {tuple(input_shape[:2])} must be "
                f"divisible by 2**depth = {stride}"
            )
    if learning_rate is None:
        learning_rate = CONFIG['learning_rate']

    def conv_pair(tensor, filters):
        # Two 'same'-padded 3x3 ReLU convolutions; spatial size is unchanged.
        tensor = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)

    inputs = layers.Input(shape=input_shape)
    x = inputs
    skips = []

    # Encoder
    for i in range(depth):
        x = conv_pair(x, base_filters * 2**i)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = conv_pair(x, base_filters * 2**depth)

    # Decoder (deepest skip connection first)
    for i in reversed(range(depth)):
        x = layers.UpSampling2D((2, 2))(x)
        x = layers.Concatenate()([x, skips[i]])
        x = conv_pair(x, base_filters * 2**i)

    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)
    model = models.Model(inputs, outputs)
    # Binary cross-entropy against [0, 1] targets; MSE/MAE are tracked as metrics.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss='binary_crossentropy', metrics=['mse', 'mae'])
    return model

# ==================== Main Program ====================
# Hold-out fraction passed to model.fit. Keras uses the LAST fraction of X/y
# (taken before shuffling) as validation data; the same slice is reused below.
VALIDATION_SPLIT = 0.2
# Sigmoid post-processing parameters, shared by the plot and the metrics report.
POSTPROCESS_THRESHOLD = 0.5
POSTPROCESS_SCALE = 10

# Generate data
X, y = generate_synthetic_data(CONFIG['num_samples'], CONFIG['image_size'])
X = X[..., np.newaxis]  # add a channel axis for Conv2D
y = y[..., np.newaxis]
# Index of the first validation sample, computed the same way Keras splits the data.
split_at = int(np.floor(X.shape[0] * (1.0 - VALIDATION_SPLIT)))

# Create results directory
result_dir = 'results'
os.makedirs(result_dir, exist_ok=True)

# Train model
print("Training U-Net...")
# Take the input shape from the data instead of hard-coding (64, 64, 1), so a
# different CONFIG['image_size'] cannot desynchronize the model and the data.
model = build_unet(input_shape=X.shape[1:], base_filters=CONFIG['base_filters'], depth=CONFIG['depth'])
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
history = model.fit(X, y, epochs=CONFIG['epochs'], batch_size=CONFIG['batch_size'],
                    validation_split=VALIDATION_SPLIT, callbacks=[early_stopping], verbose=1)

# Evaluate the model as it stands after fit() instead of reading the last history
# entry: with restore_best_weights=True the weights may come from an earlier
# epoch than the last one logged.
val_metrics = model.evaluate(X[split_at:], y[split_at:], batch_size=CONFIG['batch_size'],
                             return_dict=True, verbose=0)
val_mse = val_metrics['mse']
val_mae = val_metrics['mae']
print(f"Final validation MSE: {val_mse:.4f}, MAE: {val_mae:.4f}")

# Save loss curve
plt.figure(figsize=(10, 5))
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('U-Net Training Curve (Ring Detection)')
plt.savefig(os.path.join(result_dir, "training_loss.png"), dpi=100)
plt.close()

# Predict and visualize a held-out (validation) sample, not one the model trained on
idx = np.random.randint(split_at, X.shape[0])
sample_input = X[idx:idx+1]
sample_true = y[idx]
sample_pred = model.predict(sample_input)[0]
sample_pred_enhanced = postprocess_prediction(
    sample_pred, threshold=POSTPROCESS_THRESHOLD, scale=POSTPROCESS_SCALE)

panels = [
    ('Input (Ring with Noise)', sample_input[0, :, :, 0]),
    ('True Label', sample_true[:, :, 0]),
    ('Raw Prediction', sample_pred[:, :, 0]),
    ('Enhanced Prediction (Sigmoid)', sample_pred_enhanced[:, :, 0]),
]
plt.figure(figsize=(16, 4))
for position, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.title(title)
    # Invert so high values (the ring) are drawn dark on the gray colormap.
    plt.imshow(1 - image, cmap='gray')
    plt.colorbar()

plt.tight_layout()
plt.savefig(os.path.join(result_dir, "prediction_comparison.png"), dpi=100)
plt.close()

# Save metrics
with open(os.path.join(result_dir, "metrics.txt"), "w", encoding="utf-8") as f:
    f.write("Ring Detection U-Net with Sigmoid Enhancement\n")
    f.write(f"Val MSE: {val_mse:.4f}\n")
    f.write(f"Val MAE: {val_mae:.4f}\n")
    f.write(f"Samples: {CONFIG['num_samples']}, Epochs: {len(history.history['loss'])}\n")
    f.write(f"\nPostprocessing: Sigmoid function with threshold={POSTPROCESS_THRESHOLD}, "
            f"scale={POSTPROCESS_SCALE}\n")

print(f"Training completed! Results saved to: {result_dir}/")