# Seismic Facies Classification Using Deep Learning

Implementation of:
**"A deep learning framework for seismic facies classification"**  
*Kaur et al., 2022, Interpretation*

This notebook demonstrates the complete workflow:
1. Data loading and preprocessing
2. Model training (DeepLabv3+ and GAN)
3. Model evaluation and testing
4. Uncertainty estimation
5. Visualization and comparison

---

## 1. Setup and Imports

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Import custom modules
from data_loader import (
    SeismicFaciesDataset,
    create_patches_from_volume,
    get_dataloaders,
    create_dummy_data
)

from model import (
    DeepLabV3Plus,
    GANSegmentation,
    get_model
)

from train import (
    Trainer,
    train_model
)

from test import (
    Tester,
    test_model,
    compare_models
)

from utils import (
    evaluate_model,
    estimate_uncertainty,
    visualize_prediction,
    plot_metrics,
    plot_confusion_matrix
)

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Configuration

Set parameters according to the paper:
- Patch size: 200 × 200
- Number of classes: 6 (facies types)
- Batch size: 32
- Epochs: 60
- Optimizer: Adam

In [None]:
# Configuration
CONFIG = {
    # Data parameters
    'patch_size': 200,          # As per paper
    'num_classes': 6,           # As per paper (6 facies types)
    'in_channels': 1,           # Grayscale seismic data
    
    # Training parameters
    'batch_size': 32,           # As per paper
    'num_epochs': 60,           # As per paper (for GAN)
    'learning_rate': 1e-4,      # Not specified in paper; PyTorch's Adam default is 1e-3
    'num_workers': 4,
    
    # Uncertainty estimation
    'num_mc_samples': 20,       # MC dropout samples (not specified, using common value)
    
    # Paths
    'data_dir': './data',
    'checkpoint_dir': './checkpoints',
    'results_dir': './results'
}

# Create directories
for dir_path in [CONFIG['data_dir'], CONFIG['checkpoint_dir'], CONFIG['results_dir']]:
    Path(dir_path).mkdir(parents=True, exist_ok=True)

print("Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## 3. Data Loading

### Option A: Load Real Seismic Data

If you have seismic data, load it here. Expected format:
- Seismic data: numpy array of shape (N, 200, 200)
- Labels: numpy array of shape (N, 200, 200) with values 0-5

```python
# Example:
train_seismic = np.load('path/to/train_seismic.npy')
train_labels = np.load('path/to/train_labels.npy')
val_seismic = np.load('path/to/val_seismic.npy')
val_labels = np.load('path/to/val_labels.npy')
test_seismic = np.load('path/to/test_seismic.npy')
test_labels = np.load('path/to/test_labels.npy')
```
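
Before handing real arrays to the data loaders, it may help to confirm that they match the format listed above. The following is a minimal sketch using only NumPy. `check_patches` is a hypothetical helper, not part of this project's modules, and the integer-dtype and finite-value checks are extra assumptions beyond the stated format.

```python
import numpy as np


def check_patches(seismic, labels, patch_size=200, num_classes=6):
    """Raise ValueError if the arrays do not match the expected (N, H, W) layout."""
    if seismic.ndim != 3 or seismic.shape[1:] != (patch_size, patch_size):
        raise ValueError(
            f"seismic must be (N, {patch_size}, {patch_size}), got {seismic.shape}"
        )
    if labels.shape != seismic.shape:
        raise ValueError(f"labels shape {labels.shape} != seismic shape {seismic.shape}")
    if not np.issubdtype(labels.dtype, np.integer):
        raise ValueError("labels must have an integer dtype")
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError(f"labels must lie in 0..{num_classes - 1}")
    if not np.isfinite(seismic).all():
        raise ValueError("seismic contains NaN or infinite values")


# Tiny synthetic example standing in for real data
rng = np.random.default_rng(0)
seismic = rng.standard_normal((4, 200, 200)).astype(np.float32)
labels = rng.integers(0, 6, size=(4, 200, 200))  # upper bound is exclusive
check_patches(seismic, labels)
print("Arrays match the expected format")
```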

In [None]:
# Option A: Load your real data (uncomment and modify if you have data)
# train_seismic = np.load('path/to/train_seismic.npy')
# train_labels = np.load('path/to/train_labels.npy')
# val_seismic = np.load('path/to/val_seismic.npy')
# val_labels = np.load('path/to/val_labels.npy')
# test_seismic = np.load('path/to/test_seismic.npy')
# test_labels = np.load('path/to/test_labels.npy')

### Option B: Generate Dummy Data for Testing

In [None]:
# Option B: Generate dummy data for demonstration
print("Generating dummy seismic data...")

# According to paper: 27,648 training patches
# For demo purposes, we use smaller numbers
train_seismic, train_labels = create_dummy_data(
    num_samples=1000,  # Use 27648 for full-scale training
    patch_size=CONFIG['patch_size'],
    num_classes=CONFIG['num_classes'],
    save_path=CONFIG['data_dir']
)

val_seismic, val_labels = create_dummy_data(
    num_samples=200,
    patch_size=CONFIG['patch_size'],
    num_classes=CONFIG['num_classes']
)

test_seismic, test_labels = create_dummy_data(
    num_samples=200,
    patch_size=CONFIG['patch_size'],
    num_classes=CONFIG['num_classes']
)

print("\nData shapes:")
print(f"  Train: {train_seismic.shape}")
print(f"  Val:   {val_seismic.shape}")
print(f"  Test:  {test_seismic.shape}")

### Visualize Sample Data

In [None]:
# Visualize some samples
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

facies_names = [
    'Basement rocks',
    'Slope mudstone A',
    'Mass-transport complex',
    'Slope mudstone B',
    'Slope valley',
    'Submarine canyon'
]

for i in range(3):
    # Seismic
    axes[0, i].imshow(train_seismic[i], cmap='seismic', aspect='auto')
    axes[0, i].set_title(f'Seismic Sample {i+1}')
    axes[0, i].axis('off')
    
    # Labels
    im = axes[1, i].imshow(train_labels[i], cmap='tab10', vmin=0, vmax=5, aspect='auto')
    axes[1, i].set_title(f'Facies Labels {i+1}')
    axes[1, i].axis('off')

plt.tight_layout()
plt.savefig(Path(CONFIG['results_dir']) / 'sample_data.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nFacies classes:")
for i, name in enumerate(facies_names):
    print(f"  {i}: {name}")

### Create Data Loaders

In [None]:
# Create data loaders
train_loader, val_loader = get_dataloaders(
    train_seismic, train_labels,
    val_seismic, val_labels,
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers']
)

test_dataset = SeismicFaciesDataset(test_seismic, test_labels, normalize=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=CONFIG['num_workers']
)

print("\nData loaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches:   {len(val_loader)}")
print(f"  Test batches:  {len(test_loader)}")

## 4. Model Training

### 4.1 Train DeepLabv3+

In [None]:
# Train DeepLabv3+
print("\n" + "="*70)
print("TRAINING DEEPLABV3+")
print("="*70)

deeplab_history = train_model(
    model_type='deeplabv3+',
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=CONFIG['num_epochs'],
    learning_rate=CONFIG['learning_rate'],
    device=device,
    checkpoint_dir=CONFIG['checkpoint_dir']
)

### 4.2 Train GAN

In [None]:
# Train GAN
print("\n" + "="*70)
print("TRAINING GAN")
print("="*70)

gan_history = train_model(
    model_type='gan',
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=CONFIG['num_epochs'],
    learning_rate=CONFIG['learning_rate'],
    device=device,
    checkpoint_dir=CONFIG['checkpoint_dir']
)

### 4.3 Plot Training History

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss curves
axes[0].plot(deeplab_history['train_loss'], label='DeepLabv3+ Train', linewidth=2)
axes[0].plot(deeplab_history['val_loss'], label='DeepLabv3+ Val', linewidth=2, linestyle='--')
axes[0].plot(gan_history['train_loss'], label='GAN Train', linewidth=2)
axes[0].plot(gan_history['val_loss'], label='GAN Val', linewidth=2, linestyle='--')
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training and Validation Loss', fontsize=14)
axes[0].legend(fontsize=10)
axes[0].grid(alpha=0.3)

# F1 score curves
deeplab_train_f1 = [m['mean_f1'].item() for m in deeplab_history['train_metrics']]
deeplab_val_f1 = [m['mean_f1'].item() for m in deeplab_history['val_metrics']]
gan_train_f1 = [m['mean_f1'].item() for m in gan_history['train_metrics']]
gan_val_f1 = [m['mean_f1'].item() for m in gan_history['val_metrics']]

axes[1].plot(deeplab_train_f1, label='DeepLabv3+ Train', linewidth=2)
axes[1].plot(deeplab_val_f1, label='DeepLabv3+ Val', linewidth=2, linestyle='--')
axes[1].plot(gan_train_f1, label='GAN Train', linewidth=2)
axes[1].plot(gan_val_f1, label='GAN Val', linewidth=2, linestyle='--')
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('F1 Score', fontsize=12)
axes[1].set_title('Training and Validation F1 Score', fontsize=14)
axes[1].legend(fontsize=10)
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.savefig(Path(CONFIG['results_dir']) / 'training_history.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Model Evaluation

### 5.1 Evaluate DeepLabv3+

In [None]:
# Test DeepLabv3+
deeplab_checkpoint = Path(CONFIG['checkpoint_dir']) / 'deeplabv3+_best.pth'

deeplab_metrics = test_model(
    model_type='deeplabv3+',
    checkpoint_path=str(deeplab_checkpoint),
    test_loader=test_loader,
    device=device,
    save_dir=Path(CONFIG['results_dir']) / 'deeplabv3+',
    visualize=True,
    estimate_uncertainty=True,
    num_mc_samples=CONFIG['num_mc_samples']
)

### 5.2 Evaluate GAN

In [None]:
# Test GAN
gan_checkpoint = Path(CONFIG['checkpoint_dir']) / 'gan_best.pth'

gan_metrics = test_model(
    model_type='gan',
    checkpoint_path=str(gan_checkpoint),
    test_loader=test_loader,
    device=device,
    save_dir=Path(CONFIG['results_dir']) / 'gan',
    visualize=True,
    estimate_uncertainty=True,
    num_mc_samples=CONFIG['num_mc_samples']
)

### 5.3 Compare Models

As described in the paper, we perform a comparative analysis of DeepLabv3+ and GAN results.
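
One way to put the two models on the same footing is to compute per-class precision, recall, and F1 from a confusion matrix. The sketch below does this with NumPy on toy arrays. `per_class_scores` is a hypothetical helper, and the project's own `compare_models` and `evaluate_model` may compute or average these metrics differently.

```python
import numpy as np


def per_class_scores(y_true, y_pred, num_classes=6):
    """Return (precision, recall, f1) arrays with one entry per class."""
    # Encode each (true, predicted) pair as one index and count occurrences
    idx = y_true.ravel() * num_classes + y_pred.ravel()
    cm = np.bincount(idx, minlength=num_classes**2).reshape(num_classes, num_classes)
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)  # pixels predicted as each class
    actual = cm.sum(axis=1)     # pixels truly belonging to each class
    # Classes with no predicted or no true pixels get a score of 0
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1


# Toy 3-class example: pred_a has two wrong pixels, pred_b is perfect
y_true = np.array([[0, 0, 1], [1, 2, 2]])
pred_a = np.array([[0, 1, 1], [1, 2, 0]])
pred_b = y_true.copy()

scores_a = per_class_scores(y_true, pred_a, num_classes=3)
scores_b = per_class_scores(y_true, pred_b, num_classes=3)
print("Macro F1, model A:", scores_a[2].mean())
print("Macro F1, model B:", scores_b[2].mean())
```

Reporting the per-class arrays alongside the mean is useful when facies are imbalanced, because a single mean can hide a poorly predicted class.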

In [None]:
# Compare models side-by-side
comparison_results = compare_models(
    deeplab_checkpoint=str(deeplab_checkpoint),
    gan_checkpoint=str(gan_checkpoint),
    test_loader=test_loader,
    device=device,
    save_dir=Path(CONFIG['results_dir']) / 'comparison'
)

## 6. Uncertainty Analysis

As described in the paper, we use Monte Carlo Dropout for epistemic uncertainty estimation.
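
In Monte Carlo Dropout, dropout stays active at inference time and the same input is passed through the network several times. The sketch below covers only the aggregation step: it takes a stack of softmax outputs, averages them, and summarizes their spread as predictive entropy. The stochastic passes are simulated with random numbers here. `aggregate_mc_samples` is a hypothetical helper, and `predict_with_uncertainty` in this project may use a different measure, such as the variance across samples. Predictive entropy also mixes epistemic and aleatoric contributions, so it is only one possible summary.

```python
import numpy as np


def aggregate_mc_samples(probs):
    """Combine MC samples of shape (T, C, H, W) into a prediction and an entropy map."""
    mean_probs = probs.mean(axis=0)         # (C, H, W): average over the T passes
    prediction = mean_probs.argmax(axis=0)  # (H, W): most likely class per pixel
    # Predictive entropy of the averaged distribution; epsilon avoids log(0)
    entropy = -(mean_probs * np.log(mean_probs + 1e-12)).sum(axis=0)
    return prediction, entropy


# Simulated softmax outputs standing in for T forward passes with dropout active
rng = np.random.default_rng(0)
T, C, H, W = 20, 6, 8, 8
logits = rng.standard_normal((T, C, H, W))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

prediction, entropy = aggregate_mc_samples(probs)
print("Prediction map shape:", prediction.shape)
print("Entropy range:", float(entropy.min()), "to", float(entropy.max()))
```

The entropy is bounded above by the logarithm of the number of classes, so maps from different runs can be compared on a common scale.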

In [None]:
# Detailed uncertainty analysis
print("\n" + "="*70)
print("UNCERTAINTY ANALYSIS")
print("="*70)

# Load model
model = get_model(
    'deeplabv3+',
    in_channels=CONFIG['in_channels'],
    num_classes=CONFIG['num_classes']
)
tester = Tester(model, 'deeplabv3+', device=device)
tester.load_checkpoint(str(deeplab_checkpoint))

# Get a batch for analysis
seismic_batch, labels_batch = next(iter(test_loader))

# Predict with uncertainty
print(f"\nEstimating uncertainty with {CONFIG['num_mc_samples']} MC samples...")
predictions, uncertainty = tester.predict_with_uncertainty(
    seismic_batch,
    num_samples=CONFIG['num_mc_samples']
)

# Visualize uncertainty for first sample
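# Move tensors to the CPU before converting, in case the tester returns them on the GPU
seismic_np = seismic_batch[0, 0].cpu().numpy()
label_np = labels_batch[0].cpu().numpy()
pred_np = predictions[0].detach().cpu().numpy()
uncert_np = uncertainty[0].detach().cpu().numpy()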

fig, axes = plt.subplots(2, 2, figsize=(12, 12))

axes[0, 0].imshow(seismic_np, cmap='seismic', aspect='auto')
axes[0, 0].set_title('Seismic Section', fontsize=14)
axes[0, 0].axis('off')

axes[0, 1].imshow(label_np, cmap='tab10', vmin=0, vmax=5, aspect='auto')
axes[0, 1].set_title('True Labels', fontsize=14)
axes[0, 1].axis('off')

axes[1, 0].imshow(pred_np, cmap='tab10', vmin=0, vmax=5, aspect='auto')
axes[1, 0].set_title('Predicted Labels', fontsize=14)
axes[1, 0].axis('off')

im = axes[1, 1].imshow(uncert_np, cmap='hot', aspect='auto')
axes[1, 1].set_title('Epistemic Uncertainty', fontsize=14)
axes[1, 1].axis('off')
plt.colorbar(im, ax=axes[1, 1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.savefig(Path(CONFIG['results_dir']) / 'uncertainty_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nUncertainty statistics:")
print(f"  Mean: {uncert_np.mean():.4f}")
print(f"  Std:  {uncert_np.std():.4f}")
print(f"  Min:  {uncert_np.min():.4f}")
print(f"  Max:  {uncert_np.max():.4f}")

## 7. Summary and Key Findings

Based on the paper's findings:
- **DeepLabv3+**: Produces sharper boundaries between facies
- **GAN**: Better continuity of predicted facies
- **Joint Analysis**: Combining predictions from both networks provides more accurate interpretation
- **Uncertainty**: High uncertainty regions often correspond to facies boundaries or mispredicted areas
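
The joint analysis can be approached in several ways, and this notebook does not fix a particular procedure. As one simple illustration, the sketch below marks the pixels where two prediction maps agree and flags the rest for manual review. `agreement_map` and the `-1` review marker are assumptions made for this example.

```python
import numpy as np


def agreement_map(pred_a, pred_b):
    """Return a boolean map of agreeing pixels and the overall agreement fraction."""
    agree = pred_a == pred_b
    return agree, float(agree.mean())


# Toy prediction maps standing in for DeepLabv3+ and GAN outputs
pred_deeplab = np.array([[0, 0, 1], [2, 2, 1]])
pred_gan = np.array([[0, 1, 1], [2, 2, 2]])

agree, fraction = agreement_map(pred_deeplab, pred_gan)

# Keep the shared label where the models agree; mark disagreements with -1
joint = np.where(agree, pred_deeplab, -1)
print(f"Models agree on {fraction:.1%} of pixels")
print(joint)
```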

In [None]:
# Print final summary
print("\n" + "="*70)
print("FINAL SUMMARY")
print("="*70)

print("\nDeepLabv3+ Performance:")
print(f"  Mean Precision: {deeplab_metrics['mean_precision']:.4f}")
print(f"  Mean Recall:    {deeplab_metrics['mean_recall']:.4f}")
print(f"  Mean F1 Score:  {deeplab_metrics['mean_f1']:.4f}")

print("\nGAN Performance:")
print(f"  Mean Precision: {gan_metrics['mean_precision']:.4f}")
print(f"  Mean Recall:    {gan_metrics['mean_recall']:.4f}")
print(f"  Mean F1 Score:  {gan_metrics['mean_f1']:.4f}")

print("\nKey Observations:")
print("  - DeepLabv3+ captures sharper facies boundaries (ASPP + encoder-decoder)")
print("  - GAN provides better facies continuity (adversarial training)")
print("  - Uncertainty is higher at facies boundaries and mispredicted regions")
print("  - Joint analysis of both models recommended for interpretation")

print("\n" + "="*70)
print(f"All results saved to: {CONFIG['results_dir']}")
print("="*70)

## 8. Inference on New Data (Optional)

Use trained models to predict facies on new seismic data.
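
New sections rarely arrive already cut into 200 × 200 patches. The sketch below shows one way to tile a 2D section into non-overlapping patches and to stitch per-patch results back together. It zero-pads the edges and uses no overlap, which are assumptions of this example. `split_into_patches` and `stitch_patches` are hypothetical helpers, and the project's own `create_patches_from_volume` and `predict_full_volume` may handle tiling differently.

```python
import numpy as np


def split_into_patches(section, patch_size=200):
    """Zero-pad a 2D section to a multiple of patch_size and cut it into tiles."""
    h, w = section.shape
    pad_h = (-h) % patch_size
    pad_w = (-w) % patch_size
    padded = np.pad(section, ((0, pad_h), (0, pad_w)), mode="constant")
    rows = padded.shape[0] // patch_size
    cols = padded.shape[1] // patch_size
    patches = (
        padded.reshape(rows, patch_size, cols, patch_size)
        .swapaxes(1, 2)
        .reshape(rows * cols, patch_size, patch_size)
    )
    return patches, (rows, cols)


def stitch_patches(patches, grid, original_shape):
    """Reassemble row-major tiles and crop away the padding."""
    rows, cols = grid
    p = patches.shape[1]
    full = patches.reshape(rows, cols, p, p).swapaxes(1, 2).reshape(rows * p, cols * p)
    return full[: original_shape[0], : original_shape[1]]


# Toy section whose size is not a multiple of the patch size
section = np.arange(450 * 620, dtype=np.float32).reshape(450, 620)
patches, grid = split_into_patches(section)
restored = stitch_patches(patches, grid, section.shape)
print("Patches:", patches.shape, "grid:", grid)
print("Round trip exact:", np.array_equal(restored, section))
```

The same `stitch_patches` call can be applied to per-patch label or uncertainty maps, since they share the patch layout of the input.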

In [None]:
# Example: Predict on a new seismic volume
# Uncomment and modify if you have new data to predict

# # Load new seismic data
# new_seismic = np.load('path/to/new_seismic.npy')  # Shape: (N, 200, 200)

# # Predict using DeepLabv3+
# model = get_model('deeplabv3+', in_channels=1, num_classes=6)
# tester = Tester(model, 'deeplabv3+', device=device)
# tester.load_checkpoint(str(deeplab_checkpoint))

# predictions, uncertainty = tester.predict_full_volume(
#     new_seismic,
#     batch_size=CONFIG['batch_size'],
#     estimate_uncertainty_flag=True,
#     num_mc_samples=CONFIG['num_mc_samples']
# )

# # Save predictions
# np.save(Path(CONFIG['results_dir']) / 'predictions.npy', predictions)
# np.save(Path(CONFIG['results_dir']) / 'uncertainty.npy', uncertainty)

# print(f"Predictions saved to {CONFIG['results_dir']}")

---

## References

Kaur, H., Pham, N., Fomel, S., Geng, Z., Decker, L., Gremillion, B., Jervis, M., Abma, R., & Gao, S. (2022). A deep learning framework for seismic facies classification. *Interpretation*, 11(1), T107-T116.

---