Skip to content

itYangYYi/CBA-DiffU_Net

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

CBA-DiffU_Net

Occam2DMT 二维大地电磁反演的深度学习加速框架

本项目面向 Occam2DMT 二维大地电磁(2D MT)反演流程,使用深度神经网络学习「中间反演结果 → 最终收敛模型」的非线性映射,从而在保持反演精度的前提下显著降低传统迭代反演所需的计算开销。

仓库同时提供两条互补的建模路线:

  • 判别式路线:UNet / UNet++ / UNet++ with CBAM,一次前向推理即得结果,速度快,适合批量场景;
  • 生成式路线:基于 DDPM 的条件扩散模型,以平滑反演结果为条件逐步去噪,更适合恢复锐利的电阻率突变边界。

两条路线共享同一套数据管线、网格模板与可视化工具,可按需切换。


特性

  • 同时覆盖 U-Net 家族和条件扩散模型两套架构,一个仓库内完成对照实验
  • 针对 20×64 物理网格定制的对称填充与裁剪机制,避免非 2 幂次分辨率引发的对齐误差
  • 两种面向稀疏异常体的损失函数:空间梯度 L1 损失 与 边缘加权 L1 损失
  • 数据集支持一次性预加载到内存,训练期间零 I/O
  • 推理脚本自动识别模型架构(U-Net / U-Net++ / CBAM / Diffusion)
  • 提供基于线程池的批量预测脚本,支持对整个测试集高通量推理

项目结构

CBA-DiffU_Net/
├── data/                              # 数据集根目录(默认被 .gitignore 忽略)
│   └── auto_training_data/            # 训练样本集合
├── models/                            # 网络结构与损失函数
│   ├── unet/
│   │   ├── unet_model.py              # 基础 UNet64x20
│   │   ├── unet_plus_model.py         # UNet++ 64x20
│   │   └── unet_plus_model_CBAM.py    # 融合 CBAM 的 UNet++
│   ├── diffusion/
│   │   ├── diffusion_config.py        # 扩散调度/网络/训练超参集中配置
│   │   ├── gaussian_diffusion.py      # DDPM 前向加噪调度器
│   │   └── unet_plus_model_CBAM_diffusion.py  # 带时间嵌入的条件去噪 U-Net
│   └── loss/
│       ├── GradientL1Loss.py          # 空间梯度 L1 损失
│       └── EdgeWeightedL1Loss.py      # 边缘加权 L1 损失
├── scripts/                           # 训练与推理脚本
│   ├── train_unet_model.py            # U-Net 家族训练入口(一键切换架构)
│   ├── train_diffusion_model.py       # 条件扩散模型训练入口
│   ├── predict_and_plot.py            # 单样本推理并调用可视化
│   └── batch_predict.py               # 整个测试集的并行批量预测
├── utils/                             # 数据、正演与可视化工具
│   ├── dataloader.py                  # 自定义数据集与预加载逻辑
│   ├── dataset_analyzer.py            # 数据集统计与梯度分析
│   ├── Occam2DMT_python_000_initial/  # 作者自研的 Python 正演/反演/绘图管线
│   └── OCCAM2DMT_V3.0/                # 第三方 Occam2DMT 原版(见「第三方依赖」)
├── templates/                         # 网格与初始模型模板
│   ├── occam.mesh                     # 标准网格定义
│   ├── occam.model                    # 初始模型配置
│   └── startup.iter                   # 初始参考模型
├── runs/                              # 训练与推理输出(默认被 .gitignore 忽略)
│   ├── train/                         # 训练权重与损失曲线
│   └── predict/                       # 推理结果与可视化
├── .gitignore
├── LICENSE
├── README.md
└── requirements.txt

环境配置

本项目基于 Conda 虚拟环境构建,推荐按以下流程配置以确保 CUDA 与 PyTorch 之间的 ABI 对齐。

核心依赖:

  • Python 3.13.11
  • PyTorch 2.10.0(配套 CUDA 13.0)
  • torchvision 0.25.0
  • NumPy 2.4.1
  • SciPy 1.17.0
  • Pandas 3.0.0
  • Matplotlib 3.10.8
  • Pillow 12.1.0
  • tqdm 4.67.3

快速构建:

# 创建并激活专属 Conda 环境
conda create -n Occam2DMT_python python=3.13
conda activate Occam2DMT_python

# 安装适配 CUDA 13.0 的 PyTorch
pip install torch==2.10.0+cu130 torchvision==0.25.0+cu130 \
    --index-url https://download.pytorch.org/whl/cu130

# 安装其余依赖
pip install -r requirements.txt

如果没有 NVIDIA GPU,可将上面的 PyTorch 安装替换为 CPU 版本:pip install torch==2.10.0 torchvision==0.25.0。本仓库所有脚本会自动探测 CUDA 可用性并回退到 CPU。


数据格式说明

训练数据结构

每个样本必须保持如下目录层级与命名规范:

sample_XXXX/
├── startup.iter                # 初始参考模型(作为网络的真实标签 Target)
├── inversion/
│   ├── both00.iter             # 第 0 次迭代过程文件
│   ├── both01.iter             # 第 1 次迭代过程文件
│   └── ...
│   └── bothNN.iter             # 最终收敛结果(作为网络输入 Input)
├── occam.mesh                  # 当前样本对应的网格文件
└── occam.model                 # 当前样本对应的模型配置文件

数据集加载器 (utils/dataloader.py) 会自动在 inversion/ 下选取序号最大的 both*.iter 作为输入,并将 startup.iter 作为标签。

.iter 文件解析约定

Occam 反演程序输出的 *.iter 文件包含完整的反演上下文:

  • 文件头:反演参数、平滑度、数据拟合差等上下文信息;
  • 参数主体:电阻率值,默认以 Log10 形式存储;
  • 网格尺寸:本项目严格适配 20 行 × 64 列的二维网格,加载器会截取前 1280 个有效参数并重构为 2D 张量。

使用指南

1. 训练 U-Net 家族模型

训练脚本内置了三种网络架构,可通过修改脚本顶部的 MODEL_CHOICE 一键切换:

模型代号 架构 适用场景
UNet 对称式编解码网络 基线(Baseline)性能对比
UNet++ 引入密集嵌套跳跃连接的编解码网络 需要更精细的多尺度地质特征融合
UNet++_CBAM 在 UNet++ 基础上融合 CBAM 注意力 对高对比度异常体定位精度要求最高时的首选

启动训练(包含自适应学习率衰减与早停机制):

python scripts/train_unet_model.py

默认配置下:

  • 自动按 20% 比例切分验证集;
  • 使用 ReduceLROnPlateau 在验证损失停滞时衰减学习率;
  • 触发 Early Stopping(默认 patience=15)时自动回滚到最佳权重;
  • 最优权重保存为 runs/train/<model>_best.pth,同时生成训练曲线 PNG。

2. 训练条件扩散模型

条件扩散模型以平滑的反演结果作为去噪引导条件,网络预测的不是图像本身,而是每个时间步被注入的高斯噪声:

python scripts/train_diffusion_model.py

超参数集中在 models/diffusion/diffusion_config.py 中管理,包括:

  • num_timesteps:扩散总步数(默认 1000);
  • beta_schedule / beta_start / beta_end:噪声调度策略;
  • in_channels:输入通道数(带噪图 x_t + 条件图 c,共 2 通道);
  • model_channels / channel_mult / time_emb_dim:去噪网络结构参数。

训练过程启用梯度裁剪(max_norm=1.0)以避免扩散模型早期阶段梯度爆炸。最优权重保存为 runs/train/occam_diffusion_best.pth

3. 单样本推理与可视化

对单个 .iter 中间文件执行前向推理,并自动调用绘图脚本生成剖面图:

python scripts/predict_and_plot.py <input_iter_file> \
    --model <model_path> \
    --output <output_iter_path>

predict_and_plot.py 会根据权重文件名中的关键字(diffusion / cbam / plusplus)自动匹配对应的网络结构;对扩散模型则自动进入 DDPM 逆向采样循环。

可视化选项:

  • --linear:采用线性颜色尺度,替代默认的对数映射;
  • --mesh:在剖面图上叠加网格线;
  • --core:只显示核心反演区域,屏蔽边缘冗余。

4. 测试集批量预测

对测试集目录下所有样本并行推理,按原目录结构生成预测结果与图像:

python scripts/batch_predict.py \
    --test_data <test_dir> \
    --model <model_path> \
    --output <output_dir> \
    --workers 4

扩散模型的单次推理需要完整运行 DDPM 逆向采样循环,显存与算力开销都显著高于判别式模型,建议 --workers 控制在 2–4;而对 U-Net 家族可以放宽到 16 或更高。


模型架构说明

1. UNet64x20(基础架构)

专为 64×20 物理分辨率定制的基座模型。通过 4 层下采样(特征通道 64 → 512)提取高维抽象特征,再通过对称上采样恢复空间分辨率。

为了避免 20 这一非 2 幂次维度在池化过程中的对齐问题,网络内部使用对称常数填充 + 精准裁剪:前向传播初期将张量扩展为 32×64 的标准尺寸,输出前再裁剪回 20×64,确保物理维度无损。

2. UNetPlusPlus64x20(密集连接架构)

U-Net++ 在编码器与解码器之间构建了一个密集的嵌套跳跃连接网络。相比 U-Net 的单一长距离跨层拼接,密集短连接让浅层的边缘细节与高频空间信息能够逐层平滑传递到输出端,缓解重复池化导致的特征模糊。

3. AttentionUNetPlusPlus64x20(融合 CBAM)

在 U-Net++ 的密集特征融合基础上引入 CBAM(Convolutional Block Attention Module),由两部分组成:

  • 通道注意力:结合全局平均池化与最大池化,为不同特征通道分配重要性权重;
  • 空间注意力:在 2D 平面上生成权重掩码,引导网络聚焦于电阻率突变区域。

这有助于在稀疏异常体检测任务中减少背景噪声带来的无效响应。

4. Conditional DDPM(条件扩散模型)

扩散模型将训练目标从「直接回归电阻率图像」改为「在每一个时间步预测被注入的噪声」。前向过程按 DDPM 标准调度向真实标签 x_0 注入高斯噪声;逆向过程中,网络以当前带噪图 x_t 和平滑反演结果 c 的通道拼接为输入,并通过时间步嵌入感知当前所处的去噪阶段。

相比一次前向的判别式模型,扩散模型在恢复锐利边界与细小异常体的形态上具有天然优势,代价是推理需要完整跑完整个逆向采样链。


损失函数设计

大地电磁反演任务的空间特征高度不平衡:绝大多数像素是平滑背景,真正具有诊断意义的异常体只占极少比例。此时普通的像素级 L1/L2 损失容易陷入「全局抹平」的优化陷阱,本项目提供两种针对性的复合损失函数。

A. Gradient L1 Loss(空间梯度惩罚)

在像素级 L1 损失之上,额外惩罚预测图像在水平与垂直方向的梯度偏差:

$$ L_{\text{total}} = L_{1}(Y, \hat{Y}) + \alpha \left( L_{1}(\nabla_x Y, \nabla_x \hat{Y}) + L_{1}(\nabla_y Y, \nabla_y \hat{Y}) \right) $$

这一机制要求网络不仅要拟合电阻率的绝对数值,还需复刻局部区域的边界走向,对中大型异常体有明显的锐化效果。

B. Edge-Weighted L1 Loss(边缘加权 — 默认启用)

为避免梯度惩罚在极小异常体上可能引发的「双重惩罚」,本项目进一步提出基于真实标签梯度的动态加权策略:

$$ \text{Weight}(x, y) = 1.0 + \lambda \cdot \text{TargetGradient}(x, y) $$

$$ \text{Loss} = \frac{1}{N} \sum \text{Weight}(x, y) \cdot \left| Y(x, y) - \hat{Y}(x, y) \right| $$

在该框架下,平滑背景区域的误差权重保持在 1,而位于异常体边缘的像素的误差会被放大若干倍(默认上限约 16 倍),从而促使模型把有限的表达能力集中到真正需要锐化的区域。


第三方依赖

本仓库 utils/OCCAM2DMT_V3.0/ 目录下保留了 Occam2DMT v3.0 的原版源码与二进制,由加州大学圣地亚哥分校 Scripps 海洋研究所的 Marine EM Laboratory 发布,版权归原作者所有。使用前请阅读并遵守其原始许可条款,项目主页见:

https://marineemlab.ucsd.edu/Projects/Occam/index.html

本仓库其余代码采用 MIT 许可证,详见 LICENSE


参考文献

  1. U-Net: Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional networks for biomedical image segmentation. MICCAI, 234–241. https://doi.org/10.1007/978-3-319-24574-4_28
  2. U-Net++: Zhou, Z., Siddiquee, M. M. R., Tajbakhsh, N., & Liang, J. (2018). UNet++: A nested U-Net architecture for medical image segmentation. DLMIA, 3–11. https://doi.org/10.1007/978-3-030-00889-5_1
  3. CBAM: Woo, S., Park, J., Lee, J. Y., & Kweon, I. S. (2018). CBAM: Convolutional block attention module. ECCV, 3–19. https://doi.org/10.1007/978-3-030-01234-2_1
  4. DDPM: Ho, J., Jain, A., & Abbeel, P. (2020). Denoising diffusion probabilistic models. NeurIPS, 33, 6840–6851.
  5. Occam 二维反演: Constable, S. C., Parker, R. L., & Constable, C. G. (1987). Occam's inversion: A practical algorithm for generating smooth models from electromagnetic sounding data. Geophysics, 52(3), 289–300. https://doi.org/10.1190/1.1442303
  6. Adam 优化器: Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv:1412.6980.

许可证

本项目采用 MIT License。第三方代码 (utils/OCCAM2DMT_V3.0/) 请遵守其各自的许可条款。

About

No description, website, or topics provided.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages