In [None]:
# Set an environment variable to work around an OpenMP problem.
# Presumably it has to be in place before torch (imported below) loads its
# OpenMP runtime, so keep this assignment above the other imports.
# NOTE(review): KMP_DUPLICATE_LIB_OK=TRUE is commonly used to silence a
# "duplicate OpenMP runtime" error and is generally described as an unsafe
# workaround rather than a fix — confirm it is still required.
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image

# Restrict PyTorch to a single intra-op CPU thread.
torch.set_num_threads(1)

# The project root is the current working directory (resolved to an absolute
# path); the local modules generator/discriminator/dataset live there.
project_path = Path.cwd().resolve()
print(f"项目路径: {project_path}")

# Make the project root importable exactly once.
if str(project_path) not in sys.path:
    sys.path.append(str(project_path))

# 直接导入文件
from generator import Generator
from discriminator import Discriminator
from dataset import UnderwaterDataset

def verify_dataset(data_dir, extensions=('.jpg', '.jpeg')):
    """Check that ``data_dir`` looks like a usable training dataset.

    The expected layout is ``<data_dir>/images/`` holding the image files.
    Findings are reported with ``print``; the first image (in sorted order)
    is opened as a smoke test, and a failure to open it is reported but does
    not make the check fail.

    Args:
        data_dir: Dataset root directory (``str`` or ``Path``).
        extensions: File suffixes accepted as images, compared
            case-insensitively. Defaults to JPEG only, as before.

    Returns:
        True if at least one matching image was found; False if the root
        directory or its ``images`` sub-directory is missing, or no file
        matches.
    """
    data_path = Path(data_dir)

    print("\n验证数据集:")
    print(f"路径: {data_path}")
    print(f"路径存在: {data_path.exists()}")

    if not data_path.exists():
        print("错误: 数据集路径不存在")
        return False

    images_dir = data_path / 'images'
    if not images_dir.exists():
        print("错误: images目录不存在")
        return False

    # Scan the directory once and match suffixes case-insensitively. Globbing
    # '*.jpg' and '*.JPG' separately counts every file twice wherever pathlib
    # glob matching is case-insensitive (e.g. Windows). Sorting makes the
    # reported example deterministic.
    wanted = {ext.lower() for ext in extensions}
    images = sorted(
        path for path in images_dir.iterdir()
        if path.is_file() and path.suffix.lower() in wanted
    )

    total_images = len(images)
    print(f"\n找到图片: {total_images} 张")

    if images:
        print(f"示例: {images[0].name}")
        # Smoke-test that the first image can actually be decoded.
        try:
            with Image.open(images[0]) as img:
                print(f"图片尺寸: {img.size}")
        except Exception as e:
            print(f"图片加载失败: {e}")

    return total_images > 0

class Config:
    """Hyper-parameters, device choice and directory layout for a training run.

    Instantiating the class creates the checkpoint and results directories
    under the project root if they are missing.
    """

    def __init__(self):
        # Training schedule and optimiser settings (b1/b2 are passed to Adam
        # as its betas by the training loop).
        self.epochs = 10
        self.batch_size = 8
        self.lr = 2e-4
        self.b1, self.b2 = 0.5, 0.999
        self.image_size = 256
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        # Directory layout, rooted at the module-level project_path.
        root = Path(project_path)
        self.base_dir = root
        self.data_dir = root / "train_data"
        self.checkpoint_dir = root / "checkpoints"
        self.results_dir = root / "results"

        # Output directories must exist before anything is saved into them.
        for directory in (self.checkpoint_dir, self.results_dir):
            directory.mkdir(parents=True, exist_ok=True)

def _save_comparison(generator, real_imgs, out_path, num_samples=4):
    """Save a figure with input images (top row) and generator outputs (bottom row).

    Args:
        generator: the model being trained.
        real_imgs: a batch of normalised input images on the model's device.
        out_path: destination file for the figure.
        num_samples: maximum number of columns; fewer are drawn when the
            batch is smaller (the old code assumed exactly 4).
    """
    samples = real_imgs[:num_samples]

    # Sample in eval mode and restore the previous mode afterwards, so the
    # forward pass does not alter training-mode layer state (e.g. BatchNorm
    # running statistics, if the generator has such layers).
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            outputs = generator(samples)
    finally:
        generator.train(was_training)

    # Undo Normalize(mean=0.5, std=0.5) and clamp into the [0, 1] range that
    # imshow expects for float images.
    samples = (samples * 0.5 + 0.5).clamp(0, 1).cpu()
    outputs = (outputs * 0.5 + 0.5).clamp(0, 1).cpu()

    n_cols = samples.size(0)
    fig, axs = plt.subplots(2, n_cols, figsize=(4 * n_cols, 8), squeeze=False)
    for j in range(n_cols):
        for row, imgs, title in ((0, samples, 'Original'), (1, outputs, 'Enhanced')):
            ax = axs[row, j]
            ax.imshow(imgs[j].permute(1, 2, 0).numpy())
            ax.axis('off')
            ax.set_title(title)
    fig.savefig(out_path)
    plt.close(fig)


def train(checkpoint_every=10):
    """Train the generator and discriminator on the images under ``Config().data_dir``.

    Each step first updates the generator with an adversarial BCE loss plus
    100 x the L1 distance between its output and its own input image, then
    updates the discriminator on real vs. (detached) generated images.
    Every ``checkpoint_every`` epochs, and after the final epoch, both state
    dicts are saved to ``cfg.checkpoint_dir`` and a comparison figure is
    written to ``cfg.results_dir``.

    Args:
        checkpoint_every: checkpoint interval in epochs (default 10, as before).

    Raises:
        ValueError: if ``checkpoint_every`` is smaller than 1.
        RuntimeError: if dataset validation fails, or the dataset holds fewer
            images than one batch (``drop_last=True`` would leave no batches).
    """
    if checkpoint_every < 1:
        raise ValueError(f"checkpoint_every must be >= 1, got {checkpoint_every}")

    cfg = Config()
    print(f"使用设备: {cfg.device}")

    # Validate the dataset layout before building anything else.
    if not verify_dataset(cfg.data_dir):
        raise RuntimeError("数据集验证失败")

    # Resize + light colour/geometric augmentation, then map [0, 1] to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.ColorJitter(
            brightness=0.1,
            contrast=0.1,
            saturation=0.1,
            hue=0.05,
        ),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    print("\n创建数据集...")
    try:
        dataset = UnderwaterDataset(cfg.data_dir, transform=transform)
        print(f"数据集大小: {len(dataset)}")

        # Smoke-test loading of the first sample.
        if len(dataset) > 0:
            print("\n测试第一张图片加载:")
            first_image = dataset[0]
            print(f"图片尺寸: {first_image.shape}")
            print(f"数值范围: [{first_image.min():.2f}, {first_image.max():.2f}]")
    except Exception as e:
        print(f"创建数据集失败: {e}")
        raise

    print("\n创建数据加载器...")
    try:
        dataloader = DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=0,      # single-process loading
            drop_last=True,
            pin_memory=False,
        )
        print(f"数据加载器创建成功，批次数: {len(dataloader)}")
    except Exception as e:
        print(f"创建数据加载器失败: {e}")
        raise

    # With drop_last=True a dataset smaller than one batch yields no batches;
    # the per-epoch loss averaging below would then divide by zero.
    if len(dataloader) == 0:
        raise RuntimeError(
            f"数据集图片数 ({len(dataset)}) 少于 batch_size ({cfg.batch_size})，没有完整批次可供训练"
        )

    generator = Generator().to(cfg.device)
    discriminator = Discriminator().to(cfg.device)

    criterion_GAN = nn.BCELoss()
    criterion_pixel = nn.L1Loss()
    optimizer_G = optim.Adam(generator.parameters(), lr=cfg.lr, betas=(cfg.b1, cfg.b2))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=cfg.lr, betas=(cfg.b1, cfg.b2))

    print("\n开始训练...")
    latest_model_path = None  # path of the most recently saved generator

    for epoch in range(cfg.epochs):
        epoch_num = epoch + 1  # 1-based for display and file names
        progress_bar = tqdm(dataloader,
                            desc=f"Epoch {epoch_num}/{cfg.epochs}",
                            ncols=80,
                            leave=False)  # remove the bar once the epoch ends

        epoch_g_loss = 0.0
        epoch_d_loss = 0.0
        batch_count = 0

        for imgs in progress_bar:
            real_imgs = imgs.to(cfg.device)

            # --- Generator step: fool D and stay close to the input (L1) ---
            optimizer_G.zero_grad()
            gen_imgs = generator(real_imgs)
            pred_gen = discriminator(gen_imgs)
            # Labels follow the discriminator's output shape instead of a
            # hard-coded (N, 1, 16, 16) map.
            g_loss = criterion_GAN(pred_gen, torch.ones_like(pred_gen))
            pixel_loss = criterion_pixel(gen_imgs, real_imgs)
            g_total_loss = g_loss + 100 * pixel_loss
            g_total_loss.backward()
            optimizer_G.step()

            # --- Discriminator step: real -> 1, generated (detached) -> 0 ---
            optimizer_D.zero_grad()
            pred_real = discriminator(real_imgs)
            pred_fake = discriminator(gen_imgs.detach())
            real_loss = criterion_GAN(pred_real, torch.ones_like(pred_real))
            fake_loss = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            epoch_g_loss += g_total_loss.item()
            epoch_d_loss += d_loss.item()
            batch_count += 1

            progress_bar.set_postfix({
                'G_loss': f'{g_total_loss.item():.2f}',
                'D_loss': f'{d_loss.item():.2f}'
            })

        # batch_count > 0 is guaranteed by the empty-dataloader check above.
        avg_g_loss = epoch_g_loss / batch_count
        avg_d_loss = epoch_d_loss / batch_count
        print(f"Epoch {epoch_num}/{cfg.epochs} - "
              f"Avg G_loss: {avg_g_loss:.4f}, "
              f"Avg D_loss: {avg_d_loss:.4f}")

        # Checkpoint on the interval and always after the final epoch, so a run
        # whose length is not a multiple of the interval still leaves a model.
        if epoch_num % checkpoint_every == 0 or epoch_num == cfg.epochs:
            generator_path = cfg.checkpoint_dir / f'generator_epoch_{epoch_num}.pth'
            discriminator_path = cfg.checkpoint_dir / f'discriminator_epoch_{epoch_num}.pth'

            torch.save(generator.state_dict(), generator_path)
            torch.save(discriminator.state_dict(), discriminator_path)

            latest_model_path = generator_path
            print(f"\n模型已保存: {latest_model_path}")

            # Visual check on the last batch of the epoch.
            _save_comparison(
                generator,
                real_imgs,
                cfg.results_dir / f'comparison_epoch_{epoch_num}.png',
            )

    if latest_model_path:
        print(f"\n最终模型保存在: {latest_model_path}")

# Report the runtime environment before training starts.
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA是否可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"当前GPU: {torch.cuda.get_device_name(0)}")

# Entry point: run training; on failure print the error message followed by
# the full traceback instead of letting the exception propagate.
if __name__ == "__main__":
    import traceback

    try:
        train()
    except Exception as exc:
        print(f"\n训练失败: {exc}")
        traceback.print_exc()

In [None]:
import torch
from torchvision import transforms
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from generator import Generator
import os
import cv2
import numpy as np
import pandas as pd
from skimage.metrics import peak_signal_noise_ratio as psnr
import math
import scipy.stats

class Tester:
    """Evaluate a trained generator on a folder of test images.

    For every test image the generator output is post-processed (CLAHE on L,
    detail enhancement, contrast/brightness, colour denoising), saved as
    ``enhanced_<stem>.png`` together with a side-by-side
    ``comparison_<stem>.png``, and scored with PSNR (against the input image),
    UCIQE and a UIQM-style score. The scores go to the
    ``attachment 2 results`` sheet of an Excel workbook, with a CSV fallback.
    """

    # Suffixes (compared case-insensitively) treated as test images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, model_path=None, base_dir=None):
        """Resolve and validate paths, then load the generator weights.

        Args:
            model_path: generator checkpoint to load. When omitted, the
                ``generator_epoch_<N>.pth`` in ``<base_dir>/checkpoints`` with
                the largest N is used, so the tester works with whatever epoch
                count training ran for (a fixed ``generator_epoch_20.pth`` does
                not exist after shorter runs).
            base_dir: project root; defaults to the module-level ``project_path``.

        Raises:
            RuntimeError: if the base directory, the checkpoint or the test
                image directory does not exist.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.image_size = 256

        # Paths
        self.base_dir = Path(project_path if base_dir is None else base_dir)
        checkpoint_dir = self.base_dir / "checkpoints"
        if model_path is None:
            model_path = self._find_latest_checkpoint(checkpoint_dir)
            if model_path is None:
                # No checkpoint at all: keep the old default name so the
                # "missing model" error below names a concrete file.
                model_path = checkpoint_dir / "generator_epoch_20.pth"
        self.model_path = Path(model_path)
        self.test_dir = self.base_dir / "test_data" / "images"
        self.results_dir = self.base_dir / "test_results"
        self.excel_path = self.base_dir / "results" / "evaluation_results.xls"

        if not self.base_dir.exists():
            raise RuntimeError(f"基础目录不存在: {self.base_dir}")
        if not self.model_path.exists():
            raise RuntimeError(f"模型文件不存在: {self.model_path}")
        if not self.test_dir.exists():
            raise RuntimeError(f"测试目录不存在: {self.test_dir}")

        self.results_dir.mkdir(parents=True, exist_ok=True)
        self.excel_path.parent.mkdir(parents=True, exist_ok=True)

        # Same resize/normalisation as training, without augmentation.
        self.transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        print(f"加载模型: {self.model_path}")
        self.generator = Generator().to(self.device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        state_dict = torch.load(self.model_path, map_location=self.device)
        self.generator.load_state_dict(state_dict)
        self.generator.eval()

    @staticmethod
    def _find_latest_checkpoint(checkpoint_dir):
        """Return the ``generator_epoch_<N>.pth`` with the largest integer N, or None."""
        checkpoint_dir = Path(checkpoint_dir)
        if not checkpoint_dir.is_dir():
            return None
        prefix = 'generator_epoch_'
        best_epoch, best_path = -1, None
        for candidate in checkpoint_dir.glob(f'{prefix}*.pth'):
            epoch_str = candidate.stem[len(prefix):]
            # Compare numerically: lexically "epoch_100" < "epoch_20".
            if epoch_str.isdecimal() and int(epoch_str) > best_epoch:
                best_epoch, best_path = int(epoch_str), candidate
        return best_path

    @staticmethod
    def _writable_excel_path(path):
        """Return a path openpyxl can write: a legacy ``.xls`` is redirected to ``.xlsx``."""
        path = Path(path)
        if path.suffix.lower() == '.xls':
            return path.with_suffix('.xlsx')
        return path

    def enhance_image(self, img_array):
        """Apply classical post-processing to an RGB uint8 image array.

        Steps: CLAHE on the L channel in Lab space, ``cv2.detailEnhance``,
        then a linear contrast/brightness adjustment with saturation to uint8.
        """
        lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)

        # Contrast-limited adaptive histogram equalisation on lightness only.
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l = clahe.apply(l)

        lab = cv2.merge((l, a, b))
        enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

        enhanced = cv2.detailEnhance(enhanced, sigma_s=10, sigma_r=0.15)

        alpha = 1.1  # contrast gain
        beta = 5     # brightness offset
        enhanced = cv2.convertScaleAbs(enhanced, alpha=alpha, beta=beta)

        return enhanced

    def process_image(self, image_path):
        """Enhance one image file; return a PIL RGB image at the input's original size."""
        with Image.open(image_path) as src:
            img = src.convert('RGB')
        original_size = img.size

        input_tensor = self.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            enhanced = self.generator(input_tensor)

        # Undo Normalize(mean=0.5, std=0.5) and convert the single output
        # (CHW) to an HWC uint8 array.
        enhanced = enhanced * 0.5 + 0.5
        enhanced = enhanced[0].cpu().permute(1, 2, 0).numpy()
        enhanced = (enhanced * 255).clip(0, 255).astype('uint8')

        enhanced = self.enhance_image(enhanced)

        # Non-local-means colour denoising (h=10, hColor=10, template 7, search 21).
        enhanced = cv2.fastNlMeansDenoisingColored(enhanced, None, 10, 10, 7, 21)

        enhanced_img = Image.fromarray(enhanced)
        if original_size != (self.image_size, self.image_size):
            enhanced_img = enhanced_img.resize(original_size, Image.LANCZOS)

        return enhanced_img

    def calculate_psnr(self, original, enhanced):
        """PSNR of ``enhanced`` relative to ``original`` (data range inferred from dtype)."""
        return psnr(original, enhanced)

    def calculate_uciqe(self, img):
        """UCIQE-style score of an RGB uint8 image.

        0.4680 * (std of chroma) + 0.2745 * (L max - min) + 0.2576 * (mean chroma).
        """
        # Work in float: squaring the uint8 a/b channels wraps around modulo 256.
        lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB).astype(np.float64)
        l = lab[..., 0]
        # 8-bit OpenCV Lab stores a*/b* offset by +128; centre them so chroma
        # is the distance from the neutral axis.
        a = lab[..., 1] - 128.0
        b = lab[..., 2] - 128.0
        chroma = np.sqrt(np.square(a) + np.square(b))

        uc = np.mean(chroma)
        sc = np.sqrt(np.mean(np.square(chroma - uc)))
        cont = np.max(l) - np.min(l)

        return 0.4680 * sc + 0.2745 * cont + 0.2576 * uc

    def calculate_uiqm(self, img):
        """UIQM-style score of an RGB uint8 image.

        Combines mean HSV saturation, variance of the Laplacian of V and the
        standard deviation of V with fixed weights.
        NOTE(review): these components are simple proxies, not the reference
        UIQM colourfulness/sharpness/contrast measures — confirm whether
        scores must be comparable with published UIQM values.
        """
        hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
        h, s, v = cv2.split(hsv)

        uicm = np.mean(s)
        uism = cv2.Laplacian(v, cv2.CV_64F).var()
        uiconm = np.std(v)

        return 0.0282 * uicm + 0.2953 * uism + 0.6765 * uiconm

    def _save_results(self, df):
        """Write ``df`` to the results sheet of the workbook, or to CSV if that fails.

        openpyxl can only write .xlsx/.xlsm, so a ``.xls`` target is written as
        ``.xlsx`` next to it. Existing sheets are carried over from the .xlsx
        output if present, else from ``excel_path``; a missing workbook is
        created from scratch.
        """
        out_path = self._writable_excel_path(self.excel_path)
        try:
            source = out_path if out_path.exists() else self.excel_path
            sheets = pd.read_excel(source, sheet_name=None) if source.exists() else {}

            sheets['attachment 2 results'] = df

            with pd.ExcelWriter(out_path, engine='openpyxl', mode='w') as writer:
                for sheet_name, data in sheets.items():
                    data.to_excel(writer, sheet_name=sheet_name, index=False)

            print(f"\n数据已成功更新到 {out_path} 的 attachment 2 results 工作表")

        except Exception as e:
            # Best effort: never lose the computed metrics because of Excel I/O.
            print(f"更新Excel文件时出错: {str(e)}")
            csv_path = self.results_dir / "evaluation_results.csv"
            df.to_csv(csv_path, index=False)
            print(f"结果已保存到CSV文件: {csv_path}")

    def test_all(self):
        """Enhance and score every image in ``test_dir``; save images, figures and a results table."""
        print("\n开始测试...")
        test_images = sorted(
            path for path in self.test_dir.iterdir()
            if path.is_file() and path.suffix.lower() in self.IMAGE_EXTENSIONS
        )
        if not test_images:
            # Nothing to score: do not overwrite the results sheet with an
            # empty table (df.describe() would also fail without columns).
            print(f"未找到测试图片: {self.test_dir}")
            return

        results = []
        for img_path in test_images:
            print(f"处理图片: {img_path.name}")

            enhanced_img = self.process_image(img_path)

            # Load the original once; it feeds both the metrics and the figure.
            with Image.open(img_path) as src:
                original_img = src.convert('RGB')
            original = np.array(original_img)
            enhanced = np.array(enhanced_img)

            psnr_value = self.calculate_psnr(original, enhanced)
            uciqe_value = self.calculate_uciqe(enhanced)
            uiqm_value = self.calculate_uiqm(enhanced)

            results.append({
                'image file name': img_path.name,
                'PSNR': round(psnr_value, 4),
                'UCIQE': round(uciqe_value, 4),
                'UIQM': round(uiqm_value, 4)
            })

            print(f"PSNR: {round(psnr_value, 4)}")
            print(f"UCIQE: {round(uciqe_value, 4)}")
            print(f"UIQM: {round(uiqm_value, 4)}\n")

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
            ax1.imshow(original_img)
            ax1.set_title('Original')
            ax1.axis('off')
            ax2.imshow(enhanced_img)
            ax2.set_title('Enhanced')
            ax2.axis('off')
            fig.savefig(self.results_dir / f'comparison_{img_path.stem}.png',
                        bbox_inches='tight', dpi=300)
            plt.close(fig)

            enhanced_img.save(self.results_dir / f'enhanced_{img_path.stem}.png')

        df = pd.DataFrame(results)
        self._save_results(df)

        print(f"\n测试完成! 结果保存在: {self.results_dir}")
        print("\n评估结果概要:")
        print(df.describe())

if __name__ == "__main__":
    # Build the tester (validates paths, loads the checkpoint) and evaluate
    # every image in the test directory.
    Tester().test_all()