In [None]:
import cv2
import torch
import numpy as np
import os
import sys
sys.path.append(os.path.abspath("MiDaS")) 
from midas.dpt_depth import DPTDepthModel
from torchvision.transforms import Compose, Resize, Normalize, ToTensor
from tqdm import tqdm

# Paths
input_dir = 'autodl-tmp/afo/images1'         # input AFO images
output_dir = 'autodl-tmp/afo_fog/images1'    # output directory for the fogged images
os.makedirs(output_dir, exist_ok=True)

# Load the MiDaS depth-estimation model from torch.hub.
# NOTE: the first run downloads the checkpoint into the torch.hub cache
# (~/.cache/torch/hub/checkpoints); later runs reuse it.
model_type = "DPT_Hybrid"
midas = torch.hub.load("intel-isl/MiDaS", model_type, trust_repo=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
midas.to(device)
midas.eval()

# Load the MiDaS input transforms. trust_repo=True is passed here too so both
# hub calls behave the same (no trust prompt/warning on this second call).
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)

# DPT models need the DPT preprocessing; the smaller MiDaS models use
# small_transform. Selecting by model_type keeps this correct if model_type
# is changed above.
if model_type in ("DPT_Large", "DPT_Hybrid"):
    transform = midas_transforms.dpt_transform
else:
    transform = midas_transforms.small_transform

def depth_to_transmission(depth, beta=0.08, t_min=0.05):
    """Convert a depth map into a transmission map for fog synthesis.

    The depth map is min-max normalised to [0, 1] and mapped through the
    atmospheric-scattering term ``t = exp(-beta * d)``, so larger depth
    values receive lower transmission (denser fog).

    Note: MiDaS predicts relative *inverse* depth (larger = closer). Its
    output must be inverted before calling this function.

    Args:
        depth: Depth map where larger values mean farther away. Any numeric
            dtype is accepted; normalisation is done in float64, so integer
            maps are not quantised.
        beta: Scattering coefficient. Because depth is normalised to [0, 1],
            the transmission before clipping lies in [exp(-beta), 1].
        t_min: Lower bound applied to the transmission so that no region
            becomes fully opaque. Defaults to the original value of 0.05.

    Returns:
        float64 array with the same shape as ``depth``, clipped to
        ``[t_min, 1.0]``.
    """
    d = np.asarray(depth, dtype=np.float64)
    if d.size:
        d_min = d.min()
        span = d.max() - d_min
        # A constant map has no depth variation: map it to 0 (no attenuation)
        # instead of dividing by zero.
        d = (d - d_min) / span if span > 0 else np.zeros_like(d)
    t = np.exp(-beta * d)
    return np.clip(t, t_min, 1.0)

def apply_fog(img, transmission, A=255):
    """Composite fog onto an image using the atmospheric scattering model.

    Computes ``I = J * t + A * (1 - t)`` per pixel, where ``J`` is the clear
    image, ``t`` the transmission and ``A`` the atmospheric light.

    Args:
        img: Clear image, either ``(H, W)`` grayscale or ``(H, W, C)``.
        transmission: ``(H, W)`` transmission map (``(H, W, 1)`` is also
            accepted), expected in ``[0, 1]``.
        A: Atmospheric light. Either a scalar, or for colour images a
            length-C sequence giving a per-channel value in the image's
            channel order.

    Returns:
        uint8 foggy image with the same shape as ``img``, clipped to [0, 255].
    """
    img = np.asarray(img)
    t = np.asarray(transmission)
    if img.ndim == t.ndim + 1:
        # Broadcast the single-channel transmission across the colour channels.
        t = t[..., None]
    fog = img * t + np.asarray(A) * (1 - t)
    return np.clip(fog, 0, 255).astype(np.uint8)

# Process every image: estimate depth with MiDaS, then synthesise fog.
# Files are visited in sorted order so repeated runs are reproducible.
IMAGE_EXTS = ('.jpg', '.jpeg', '.png')
for fname in tqdm(sorted(os.listdir(input_dir))):
    # Case-insensitive match so e.g. "IMG_01.JPG" is not silently skipped.
    if not fname.lower().endswith(IMAGE_EXTS):
        continue
    img_path = os.path.join(input_dir, fname)
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None for unreadable or corrupt files.
        tqdm.write(f"Skipping unreadable image: {img_path}")
        continue
    # OpenCV loads BGR; MiDaS is fed RGB.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # MiDaS depth estimation, upsampled back to the original image size.
    input_tensor = transform(img_rgb).to(device)
    with torch.no_grad():
        prediction = midas(input_tensor)
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze().cpu().numpy()

    # MiDaS outputs relative *inverse* depth (larger = closer). The scattering
    # model needs depth (larger = farther), otherwise the nearest objects get
    # the densest fog, so the prediction is flipped before use.
    depth = prediction.max() - prediction

    # Convert to transmission and composite the fog. depth_to_transmission
    # normalises depth to [0, 1], so beta=0.08 yields t in roughly [0.92, 1]
    # (a light haze); raise beta for denser fog.
    t = depth_to_transmission(depth, beta=0.08)
    fog_img = apply_fog(img, t, A=255)

    # Save the result under the same file name; report write failures.
    out_path = os.path.join(output_dir, fname)
    if not cv2.imwrite(out_path, fog_img):
        tqdm.write(f"Failed to write: {out_path}")


Using cache found in /root/.cache/torch/hub/intel-isl_MiDaS_master
Downloading: "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt" to /root/.cache/torch/hub/checkpoints/dpt_hybrid_384.pt
  3%|▎         | 13.4M/470M [03:47<2:56:28, 45.2kB/s]