# SELF-SUPERVISED DENOISING: PART THREE (Using Entrypoints)
### Modernized Implementation

This notebook demonstrates how to use the entrypoint scripts to reproduce the functionality of Tutorial 3.


## Overview

This notebook uses the `train` and `infer` entrypoint scripts to:
1. Generate test data
2. Train a blind-trace denoising model
3. Run inference and visualize results

The entrypoint scripts wrap the training and inference code behind command-line arguments. A run can therefore be reproduced from its argument list, and parameters can be changed without editing any code.
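Every step below launches a script with `subprocess.run` and prints its output, so by default a failing script does not stop the notebook. A small wrapper like the one below adds an exit-status check. It is a hypothetical convenience helper (`run_entrypoint` is not part of `blindspot_denoise`), shown as a sketch:

```python
import subprocess
import sys


def run_entrypoint(args, echo=True):
    """Run a Python script/module in a subprocess and fail loudly on errors.

    `args` is everything after the interpreter, e.g.
    ["-m", "blindspot_denoise.train", "--data", "tests/training_data.npy"].
    """
    result = subprocess.run(
        [sys.executable, *args],
        capture_output=True,
        text=True,
    )
    if echo:
        print(result.stdout)
        if result.stderr:
            print("stderr:", result.stderr)
    # Raise CalledProcessError on a nonzero exit status, so a failed step
    # stops the notebook instead of letting later cells run on stale files.
    result.check_returncode()
    return result
```

For example, `run_entrypoint(["-m", "blindspot_denoise.infer", "--model", str(model_path), "--input", test_input_path, "--output", output_path])` would replace the inference call in Step 6.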


In [None]:
# Import necessary packages
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import sys
from pathlib import Path
import torch

# Import our module functions for visualization and data handling
from blindspot_denoise.models import UNet
from blindspot_denoise.utils import add_trace_wise_noise, set_seed
from blindspot_denoise.preprocessing import multi_active_pixels

# Set plotting parameters
%matplotlib inline
cmap = 'seismic'
vmin = -0.5
vmax = 0.5

# Set seed for reproducibility
set_seed(42)


## Step 1: Generate Test Data

First, we'll generate a test dataset of random seismic-like events. Alternatively, you can use your own data file.
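The generator script's event model is not shown in this notebook. As an illustration only, and assuming that "seismic-like events" means dipping linear arrivals, the sketch below builds one section from a few Ricker-wavelet events. `ricker` and `make_gather` are hypothetical helpers and need not match `tests/generate_test_data.py`:

```python
import numpy as np


def ricker(f0, dt, half):
    """Ricker wavelet with peak frequency f0 (Hz), 2*half+1 samples at interval dt (s)."""
    t = np.arange(-half, half + 1) * dt
    arg = (np.pi * f0 * t) ** 2
    return (1.0 - 2.0 * arg) * np.exp(-arg)


def make_gather(n_traces=64, n_time=128, n_events=4, dt=0.004, f0=20.0,
                wavelet_half=12, rng=None):
    """Sum of randomly placed, randomly dipping linear events.

    Returns an array of shape (n_time, n_traces): rows are time samples and
    columns are traces, matching the imshow layout used in this notebook.
    """
    rng = np.random.default_rng() if rng is None else rng
    wavelet = ricker(f0, dt, wavelet_half)
    gather = np.zeros((n_time, n_traces))
    for _ in range(n_events):
        t0 = rng.uniform(0.2, 0.8) * n_time   # arrival time at trace 0 (samples)
        dip = rng.uniform(-0.5, 0.5)          # moveout in samples per trace
        amp = rng.uniform(-1.0, 1.0)
        for ix in range(n_traces):
            it = int(round(t0 + dip * ix))    # arrival sample on this trace
            start = it - wavelet_half         # where wavelet[0] would land
            lo, hi = max(start, 0), min(start + len(wavelet), n_time)
            if lo < hi:                       # skip arrivals entirely off the section
                gather[lo:hi, ix] += amp * wavelet[lo - start:hi - start]
    return gather
```

A dataset comparable in shape to the one generated here would be `np.stack([make_gather(rng=rng) for _ in range(50)])`, giving an array of shape `(50, 128, 64)`.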


In [None]:
# Generate test data
test_data_path = "tests/test_data.npy"
print("Generating test dataset...")
result = subprocess.run(
    [
        sys.executable,
        "tests/generate_test_data.py",
        "--output", test_data_path,
        "--n-samples", "50",
        "--n-traces", "64",
        "--n-time-samples", "128",
        "--seed", "42"
    ],
    capture_output=True,
    text=True
)
print(result.stdout)
if result.stderr:
    print("stderr:", result.stderr)
result.check_returncode()  # stop here if the generator script failed

# Load and visualize the generated data
data = np.load(test_data_path)
print(f"\nData shape: {data.shape}")


In [None]:
# Visualize a sample from the generated data
plt.figure(figsize=[7, 5])
plt.imshow(data[0], cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')
plt.title('Sample from Generated Test Data')
plt.xlabel('Trace')
plt.ylabel('Time Sample')
plt.colorbar()
plt.tight_layout()
plt.show()


## Step 2: Add Trace-wise Noise

Let's add trace-wise noise to the data to simulate the noisy input we want to denoise.
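With `noisy_trace_value=0.0`, this kind of noise overwrites whole traces (columns) with a constant, i.e. it simulates dead traces. A minimal NumPy sketch of that operation, assuming the `(sample, time, trace)` layout used in the plots. `add_dead_traces` is a hypothetical stand-in, not the library's `add_trace_wise_noise`, whose realisation ordering and random choices may differ:

```python
import numpy as np


def add_dead_traces(data, num_noisy_traces, noisy_trace_value=0.0,
                    num_realisations=1, rng=None):
    """Return num_realisations copies of each section with random traces overwritten.

    data: array of shape (n_samples, n_time, n_traces); traces are the last axis.
    Output shape: (n_samples * num_realisations, n_time, n_traces), with the
    realisations of each sample stored consecutively.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_traces = data.shape[-1]
    out = np.repeat(data, num_realisations, axis=0).astype(float)  # copies data
    for patch in out:  # each `patch` is a view into `out`
        # pick distinct trace indices independently for every realisation
        cols = rng.choice(n_traces, size=num_noisy_traces, replace=False)
        patch[:, cols] = noisy_trace_value
    return out
```

Drawing new trace indices for every realisation is what makes several realisations of the same section useful: the network sees the same signal with different traces missing.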


In [None]:
# Add trace-wise noise to the data
print("Adding trace-wise noise...")
noisy_patches = add_trace_wise_noise(
    data,
    num_noisy_traces=5,
    noisy_trace_value=0.0,
    num_realisations=7,
)

# Randomize patch order
shuffler = np.random.permutation(len(noisy_patches))
noisy_patches = noisy_patches[shuffler]

print(f"Noisy patches shape: {noisy_patches.shape}")


In [None]:
# Visualize some noisy patches
fig, axs = plt.subplots(3, 6, figsize=[25, 17])
for i in range(6 * 3):
    axs.ravel()[i].imshow(noisy_patches[i], cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')
    axs.ravel()[i].set_title(f'Patch {i}')
fig.tight_layout()
plt.show()


## Step 3: Visualize Preprocessing (Active Pixel Corruption)

Let's visualize what happens during the preprocessing step where active pixels are corrupted.
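In a blind-trace scheme, the corrupted "active pixels" are whole traces. Replacing them with random values stops the network from copying a trace's own content to the output, and the mask records where the loss should be evaluated. The sketch below works under that assumption. The helper name `corrupt_traces` and the mask convention (1 on corrupted traces) are ours and may not match `multi_active_pixels`:

```python
import numpy as np


def corrupt_traces(patch, active_number, noise_level, rng=None):
    """Blind-trace style corruption of a (n_time, n_traces) patch.

    Picks `active_number` distinct traces, replaces them with uniform noise in
    [-noise_level, noise_level), and returns (corrupted_patch, mask), where mask
    is 1.0 on the corrupted traces and 0.0 elsewhere.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_time, n_traces = patch.shape
    cols = rng.choice(n_traces, size=active_number, replace=False)
    corrupted = np.array(patch, dtype=float)  # copy; the input is left untouched
    corrupted[:, cols] = rng.uniform(-noise_level, noise_level,
                                     size=(n_time, active_number))
    mask = np.zeros_like(corrupted)
    mask[:, cols] = 1.0
    return corrupted, mask
```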


In [None]:
# Visualize the corruption process
crpt_patch, mask = multi_active_pixels(
    noisy_patches[0],
    active_number=3,
    noise_level=0.5
)

# Plot the original patch, the corrupted patch, and the corruption mask side by side
fig, axs = plt.subplots(1, 3, figsize=[15, 5])
axs[0].imshow(noisy_patches[0], cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')
axs[1].imshow(crpt_patch, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')
axs[2].imshow(mask, cmap='binary_r', aspect='auto')

axs[0].set_title('Original')
axs[1].set_title('Corrupted')
axs[2].set_title('Corruption Mask')
plt.tight_layout()
plt.show()


## Step 4: Train the Model

Now we'll use the `train` entrypoint to train the denoising model. This will save checkpoints to the output directory.
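Training relies on the corruption mask: the loss compares the network output with the uncorrupted noisy input only at the corrupted positions, so the network is scored on values it never saw. A NumPy sketch of such a masked loss. Whether the `train` entrypoint uses exactly this normalisation (mean over masked positions) is an assumption:

```python
import numpy as np


def masked_mse(prediction, target, mask):
    """Mean squared error over masked (corrupted) positions only.

    Positions where mask == 0 contribute nothing, so errors on traces the
    network could see in its input do not affect the loss.
    """
    mask = np.asarray(mask).astype(bool)
    if not mask.any():
        raise ValueError("mask selects no positions")
    diff = np.asarray(prediction)[mask] - np.asarray(target)[mask]
    return float(np.mean(diff ** 2))
```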


In [None]:
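# Prepare the training data
# Save the *clean* data: the training entrypoint receives the noise settings
# (--num-noisy-traces, --num-realisations) and presumably builds its own noisy
# realisations, so noisy_patches from Step 2 is only used for visualisation here
training_data_path = "tests/training_data.npy"
np.save(training_data_path, data)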

# Set up training parameters
output_dir = "./checkpoints"
n_epochs = 20  # Reduced for notebook demo; use 150-200 for best results
n_training = 2048
n_test = 256
batch_size = 32

print("Starting training...")
print(f"Training for {n_epochs} epochs")
print(f"Output directory: {output_dir}")

# Run training using the entrypoint
result = subprocess.run(
    [
        sys.executable, "-m", "blindspot_denoise.train",
        "--data", training_data_path,
        "--output-dir", output_dir,
        "--n-epochs", str(n_epochs),
        "--n-training", str(n_training),
        "--n-test", str(n_test),
        "--batch-size", str(batch_size),
        "--hidden-channels", "32",
        "--levels", "2",
        "--num-noisy-traces", "5",
        "--num-realisations", "7",
        "--active-number", "15",
        "--noise-level", "0.25",
        "--seed", "42",
    ],
    capture_output=True,
    text=True
)

print(result.stdout)
if result.stderr:
    print("stderr:", result.stderr)
result.check_returncode()  # raise CalledProcessError if training failed


## Step 5: Visualize Training Metrics

Load and plot the training history to see how the model learned.
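`training_history.npz` holds one value per epoch for each of the keys read below. Because Step 6 loads the final (or latest) checkpoint rather than the best one, it can be useful to know which epoch reached the lowest validation loss. A small hypothetical helper; whether a checkpoint was saved for that epoch depends on the `train` entrypoint's save schedule:

```python
import tempfile
from pathlib import Path

import numpy as np


def best_epoch(history_path, key="test_loss"):
    """Return (epoch_index, value) for the smallest entry of one metric in the history file."""
    with np.load(history_path) as history:
        values = np.asarray(history[key])  # read into memory before the file closes
    idx = int(np.argmin(values))
    return idx, float(values[idx])


if __name__ == "__main__":
    # Self-contained demo with a made-up history; in the notebook, pass
    # Path(output_dir) / "training_history.npz" instead.
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "training_history.npz"
        np.savez(path, train_loss=[0.9, 0.5, 0.4], test_loss=[0.5, 0.3, 0.4])
        print(best_epoch(path))  # lowest test_loss is 0.3 at epoch index 1
```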


In [None]:
# Load training history
history_path = Path(output_dir) / "training_history.npz"
if history_path.exists():
    history = np.load(history_path)
    train_loss_history = history['train_loss']
    train_accuracy_history = history['train_accuracy']
    test_loss_history = history['test_loss']
    test_accuracy_history = history['test_accuracy']
    
    # Plot training metrics
    fig, axs = plt.subplots(1, 2, figsize=(15, 4))
    
    axs[0].plot(train_accuracy_history, 'r', lw=2, label='train')
    axs[0].plot(test_accuracy_history, 'k', lw=2, label='validation')
    axs[0].set_title('RMSE', size=16)
    axs[0].set_ylabel('RMSE', size=12)
    axs[0].legend()
    axs[0].set_xlabel('# Epochs', size=12)
    
    axs[1].plot(train_loss_history, 'r', lw=2, label='train')
    axs[1].plot(test_loss_history, 'k', lw=2, label='validation')
    axs[1].set_title('Loss', size=16)
    axs[1].set_ylabel('Loss', size=12)
    axs[1].legend()
    axs[1].set_xlabel('# Epochs', size=12)
    
    plt.tight_layout()
    plt.show()
else:
    print("Training history not found. Training probably failed; check the output of the training cell.")


## Step 6: Run Inference

Now we'll use the `infer` entrypoint to denoise a noisy sample. We'll create a new noisy realization and denoise it.
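The direct-API cell at the end of this notebook feeds the network a 4-D `(batch, channel, time, trace)` tensor built from a 2-D section. The sketch below shows that shape handling in plain NumPy for a single section or a stack of sections, with any callable standing in for the network. `run_model_on_array` is a hypothetical helper, and whether the `infer` entrypoint accepts 3-D stacks is not established here:

```python
import numpy as np


def run_model_on_array(model_fn, arr):
    """Apply `model_fn` to one 2-D section or a 3-D stack of sections.

    `model_fn` stands in for the trained network: any callable that maps a
    float32 array of shape (batch, 1, n_time, n_traces) to an array of the
    same shape. The result has the same shape as `arr`.
    """
    arr = np.asarray(arr, dtype=np.float32)
    if arr.ndim == 2:
        batch = arr[np.newaxis, np.newaxis]   # (1, 1, n_time, n_traces)
    elif arr.ndim == 3:
        batch = arr[:, np.newaxis]            # (N, 1, n_time, n_traces)
    else:
        raise ValueError(f"expected a 2-D or 3-D array, got shape {arr.shape}")
    out = np.asarray(model_fn(batch))
    # Drop the channel axis, and also the batch axis for a single section
    return out[0, 0] if arr.ndim == 2 else out[:, 0]
```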


In [None]:
# Create a fresh noisy realization of sample 0 (new random dead traces; the clean sample itself was part of the training data)
test_sample_idx = 0
testdata = add_trace_wise_noise(
    data[test_sample_idx:test_sample_idx+1],
    num_noisy_traces=5,
    noisy_trace_value=0.0,
    num_realisations=1
)[0]

print(f"Test data shape: {testdata.shape}")

# Save test data
test_input_path = "tests/test_input.npy"
np.save(test_input_path, testdata)

# Visualize the noisy test data
plt.figure(figsize=[7, 5])
plt.imshow(testdata, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')
plt.title('Noisy Test Data')
plt.xlabel('Trace')
plt.ylabel('Time Sample')
plt.colorbar()
plt.tight_layout()
plt.show()


In [None]:
# Use the final model if it exists; otherwise fall back to the latest epoch checkpoint
model_path = Path(output_dir) / "denoise_final.net"
if not model_path.exists():
    # Sort numerically by epoch number; if the numbers are not zero-padded,
    # a plain string sort would put ep10 before ep2
    checkpoints = sorted(
        Path(output_dir).glob("denoise_ep*.net"),
        key=lambda p: int("".join(ch for ch in p.stem if ch.isdigit()) or -1),
    )
    if checkpoints:
        model_path = checkpoints[-1]
        print(f"Using checkpoint: {model_path}")
    else:
        raise FileNotFoundError(f"No model checkpoint found in {output_dir}")

# Run inference
output_path = "tests/denoised_output.npy"
print(f"Running inference with model: {model_path}")
print(f"Input: {test_input_path}")
print(f"Output: {output_path}")

result = subprocess.run(
    [
        sys.executable, "-m", "blindspot_denoise.infer",
        "--model", str(model_path),
        "--input", test_input_path,
        "--output", output_path,
    ],
    capture_output=True,
    text=True
)

print(result.stdout)
if result.stderr:
    print("stderr:", result.stderr)
result.check_returncode()  # raise CalledProcessError if inference failed


## Step 7: Visualize Results

Compare the original clean data, noisy data, denoised result, and the removed noise.
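Because the noise here consists of whole traces set to `noisy_trace_value`, one quick check is whether the "Noise Removed" panel concentrates its energy on those traces. The sketch below finds constant-valued traces in the noisy input and reports the RMS of the removed component per trace. It assumes dead traces equal exactly 0.0, as configured above, and that no clean trace is identically zero; both helpers are hypothetical:

```python
import numpy as np


def trace_rms(section):
    """RMS amplitude of each trace (column) of a (n_time, n_traces) array."""
    section = np.asarray(section, dtype=float)
    return np.sqrt(np.mean(section ** 2, axis=0))


def dead_trace_report(noisy, denoised, dead_value=0.0):
    """Return (dead_trace_indices, per-trace RMS of noisy - denoised).

    A trace counts as dead if every one of its samples equals `dead_value`.
    """
    noisy = np.asarray(noisy, dtype=float)
    dead = np.flatnonzero(np.all(noisy == dead_value, axis=0))
    removed_rms = trace_rms(noisy - np.asarray(denoised, dtype=float))
    return dead, removed_rms
```

In the notebook, `dead, rms = dead_trace_report(testdata, denoised_data)` followed by comparing `rms[dead].mean()` with `np.delete(rms, dead).mean()` would show whether the model mainly changed the dead traces.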


In [None]:
# Load results
clean_data = data[test_sample_idx]
denoised_data = np.load(output_path)

# Visualize denoising performance
fig, axs = plt.subplots(1, 4, figsize=[20, 5])
axs[0].imshow(clean_data, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)
axs[1].imshow(testdata, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)
axs[2].imshow(denoised_data, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)
axs[3].imshow(testdata - denoised_data, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)

axs[0].set_title('Clean (Original)')
axs[1].set_title('Noisy')
axs[2].set_title('Denoised')
axs[3].set_title('Noise Removed')

for ax in axs:
    ax.set_xlabel('Trace')
    ax.set_ylabel('Time Sample')

plt.tight_layout()
plt.show()


In [None]:
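# Quantitative metrics: MSE against the clean section, and PSNR using the
# clean section's peak-to-peak range (max - min) as the peak value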
mse_noisy = np.mean((clean_data - testdata)**2)
mse_denoised = np.mean((clean_data - denoised_data)**2)
psnr_noisy = -10 * np.log10(mse_noisy / (np.max(clean_data) - np.min(clean_data))**2)
psnr_denoised = -10 * np.log10(mse_denoised / (np.max(clean_data) - np.min(clean_data))**2)

print("Quantitative Metrics:")
print(f"  MSE (Noisy): {mse_noisy:.6f}")
print(f"  MSE (Denoised): {mse_denoised:.6f}")
print(f"  PSNR (Noisy): {psnr_noisy:.2f} dB")
print(f"  PSNR (Denoised): {psnr_denoised:.2f} dB")
print(f"  Improvement: {psnr_denoised - psnr_noisy:.2f} dB")


## Alternative: Direct Python API Usage

Instead of using the entrypoint scripts, you can import the module and work with the network directly in Python. This gives you full control, for example to write your own training loop with custom callbacks, or to run inference on an in-memory array without writing intermediate `.npy` files.
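As an example of the kind of hook a hand-written training loop makes possible, here is a minimal early-stopping callback. It is a generic sketch, not part of `blindspot_denoise`; you would call it once per epoch with the validation loss your own loop computes:

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best = float("inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def __call__(self, epoch, val_loss):
        """Record this epoch's loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Inside a loop this reads as `if stopper(epoch, test_loss): break`, after which `stopper.best_epoch` tells you which checkpoint to keep.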


In [None]:
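# Example: run the trained network directly instead of calling the infer entrypoint
# The pickled network refers to the UNet class; importing it makes that dependency explicit
from blindspot_denoise.models import UNet

# Select a device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# The checkpoint is loaded as a complete pickled network (not a state_dict).
# PyTorch >= 2.6 defaults to weights_only=True, which rejects full pickled
# modules, so it is disabled here; only do this for checkpoints you trust.
network = torch.load(model_path, map_location=device, weights_only=False)
network.eval()

# Add batch and channel axes: (time, trace) -> (1, 1, time, trace)
torch_testdata = torch.from_numpy(np.expand_dims(testdata, axis=(0, 1))).float()

with torch.no_grad():
    test_prediction = network(torch_testdata.to(device))
test_pred = test_prediction.cpu().numpy().squeeze()

print(f"Direct inference result shape: {test_pred.shape}")
# Check against the infer entrypoint output loaded in Step 7
print("Matches entrypoint output:", np.allclose(test_pred, denoised_data, atol=1e-5))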


## Summary

This notebook demonstrated:
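1. **Data Generation**: Using the test data generator to create synthetic seismic-like data
2. **Noise and Corruption**: Adding trace-wise noise and inspecting the active-pixel corruption applied during preprocessing
3. **Training**: Using the `train` entrypoint to train a blind-trace denoising model and plotting its training history
4. **Inference**: Using the `infer` entrypoint to denoise noisy seismic data
5. **Visualization**: Comparing clean, noisy, and denoised results, with MSE and PSNR as quantitative checks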

The entrypoint scripts provide a clean, reproducible way to train and use the models, while the module functions can be imported directly for custom workflows.
