# 实验名称

In [2]:
# import package
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import torchvision.transforms as transforms

import random
import time
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Chinese text support for matplotlib: use the SimHei font for sans-serif text
# and draw minus signs with an ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams.update({
    "font.sans-serif": ["SimHei"],
    "axes.unicode_minus": False,
})

## 数据准备

In [None]:
# Custom dataset: one sub-folder per class under ``data_folder``.
class CustomDataset(Dataset):
    """Image-classification dataset laid out as ``data_folder/<class_name>/<image>``.

    Every immediate sub-directory of ``data_folder`` is a class. Class names
    (and image names within a class) are sorted so the class -> index mapping
    does not depend on the arbitrary order returned by ``os.listdir``. Plain
    files sitting directly in ``data_folder`` are ignored rather than being
    mistaken for class folders.

    Args:
        data_folder: Root directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded RGB PIL image.
        extensions: Filename suffix(es) accepted as images (case-insensitive).
    """

    def __init__(self, data_folder, transform=None, extensions=('.png', '.jpg', '.jpeg', '.gif')):
        self.data_folder = data_folder
        self.transform = transform
        if isinstance(extensions, str):
            extensions = (extensions,)
        self.extensions = tuple(ext.lower() for ext in extensions)

        # Only directories are classes; sorting makes the index mapping reproducible.
        self.classes = sorted(
            entry for entry in os.listdir(data_folder)
            if os.path.isdir(os.path.join(data_folder, entry))
        )
        # Class name -> integer label.
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        # List of (image_path, class_index) tuples.
        self.data = self.load_data()
        # Randomise the sample order (uses the global ``random`` state).
        random.shuffle(self.data)

    def load_data(self):
        """Return ``[(image_path, class_index), ...]`` for every matching image file."""
        data = []
        for cls_name in self.classes:
            cls_folder = os.path.join(self.data_folder, cls_name)
            class_idx = self.class_to_idx[cls_name]
            # Sorted listing keeps the pre-shuffle order deterministic.
            for img_name in sorted(os.listdir(cls_folder)):
                if img_name.lower().endswith(self.extensions):
                    data.append((os.path.join(cls_folder, img_name), class_idx))
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image (transformed if configured) and its label."""
        img_path, class_idx = self.data[idx]
        # The context manager closes the file; convert() returns an independent image.
        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img, class_idx

# Root folder of the dataset, one sub-folder per class (placeholder path).
data_folder = r"Data folder path"

# Resize every image to 128x128 and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Fixed seed so the sample shuffle and the train/test split are reproducible
# across runs and kernel restarts (assuming the directory listing order is
# stable); change it to draw a different split.
split_seed = 42
random.seed(split_seed)  # CustomDataset shuffles its samples with `random`
dataset = CustomDataset(data_folder, transform=transform)

# 70% / 30% train / test split.
total_samples = len(dataset)
train_samples = int(0.7 * total_samples)
test_samples = total_samples - train_samples
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_samples, test_samples],
    generator=torch.Generator().manual_seed(split_seed),
)

batch_size = 32

# Shuffle only the training batches; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print("训练集大小：", len(train_dataset))
print("测试集大小：", len(test_dataset))

In [None]:
# Sanity-check the labels: show a few training images and print their class index.
num_preview = 5
# Start at sample 34, but clamp so a small training set does not raise IndexError.
start = min(34, max(0, len(train_dataset) - num_preview))
for idx in range(start, min(start + num_preview, len(train_dataset))):
    # Fetch each sample once: indexing re-reads and re-transforms the image.
    img, label = train_dataset[idx]
    plt.figure()
    print(label)
    plt.title(str(label))
    plt.imshow(transforms.ToPILImage()(img))
plt.show()

## 函数准备

In [3]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

### 绘图函数

In [None]:
# Plot training and test loss curves.
def plot_loss(train_loss, test_loss, title):
    """Plot per-epoch training and test loss on one figure and show it.

    Args:
        train_loss: Training-loss values, one per epoch.
        test_loss: Test-loss values, same length as ``train_loss``.
        title: Figure title.
    """
    plt.figure(figsize=(8, 4), dpi=100)
    # Number epochs from 1 to match the "Epoch: k/N" lines printed during training.
    epochs = np.arange(1, len(train_loss) + 1)
    plt.plot(epochs, train_loss, label="train loss", color="red", marker='v', markersize=5, linewidth=2)
    plt.plot(epochs, test_loss, label="test loss", color="blue", marker='o', markersize=5, linewidth=2)

    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(title)
    plt.legend()
    plt.show()

# Plot training and test accuracy curves.
def plot_acc(train_acc, test_acc, title):
    """Plot per-epoch training and test accuracy on one figure and show it.

    Args:
        train_acc: Training-accuracy values, one per epoch.
        test_acc: Test-accuracy values, same length as ``train_acc``.
        title: Figure title.
    """
    plt.figure(figsize=(8, 4), dpi=100)
    # Number epochs from 1 to match the "Epoch: k/N" lines printed during training.
    epochs = np.arange(1, len(train_acc) + 1)
    plt.plot(epochs, train_acc, label="train acc", color="red", marker='v', markersize=5, linewidth=2)
    plt.plot(epochs, test_acc, label="test acc", color="blue", marker='o', markersize=5, linewidth=2)

    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title(title)
    plt.legend()
    plt.show()

# Compare the loss curves of several runs/models on one figure.
def loss_comparison(losses, labels, title):
    """Overlay several per-epoch loss curves on one figure and show it.

    Args:
        losses: Sequence of loss sequences, one per run; lengths may differ.
        labels: Legend label for each curve; ``labels[i]`` labels ``losses[i]``.
        title: Figure title.
    """
    markers = ['o', 'v', 'p', '*', 'h', 'H', '+', 'x', 'D', 'd']
    plt.figure(figsize=(8, 4), dpi=100)
    for idx, loss in enumerate(losses):
        # Each curve gets its own epoch axis (from 1), so runs of different length plot correctly.
        epochs = np.arange(1, len(loss) + 1)
        # Cycle the markers so more than len(markers) curves don't raise IndexError.
        marker = markers[idx % len(markers)]
        plt.plot(epochs, loss, label=labels[idx], marker=marker, markersize=5, linewidth=2)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(title)
    plt.legend()
    plt.show()

# Bar chart comparing training times.
def train_time_comparison(times, labels, title):
    """Draw one bar per model with its training time written above the bar.

    Args:
        times: Training time in seconds for each model.
        labels: Model name for each bar, parallel to ``times``.
        title: Figure title.
    """
    plt.figure(dpi=100)
    bars = plt.bar(labels, times, color=['blue', 'green', 'red'])

    plt.title(title)
    plt.xlabel('Models')
    plt.ylabel('Training Time (seconds)')

    # Offset the value labels by a fraction of the tallest bar instead of a
    # fixed +5 data units, so the gap looks the same for 0.5 s and 500 s runs.
    gap = 0.01 * max(times, default=0)
    # ``elapsed`` rather than ``time`` so the ``time`` module name is not shadowed.
    for bar, elapsed in zip(bars, times):
        plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + gap, f"{elapsed:.2f}",
                 ha='center', va='bottom', color='black', fontsize=12)

    plt.tight_layout()
    plt.show()

### 模型训练函数

In [None]:
# Training function: one epoch over the training loader.
def train_model(model, data_loader, criterion, optimizer, device):
    """Run one training epoch and return the per-sample mean loss and accuracy.

    Args:
        model: Classifier whose output is argmax-ed over dim 1 to get predictions
            (presumably logits of shape (batch, num_classes)).
        data_loader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss averaged over the batch (mean reduction, e.g. the
            default ``nn.CrossEntropyLoss()``); each batch loss is weighted by
            its batch size so a smaller final batch does not skew the epoch mean.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        ``(loss, acc)``: mean loss and accuracy over all samples of the epoch.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0     # number of correctly classified samples
    sample_num = 0  # number of samples seen

    for data, target in data_loader:
        data, target = data.to(device).float(), target.to(device).long()

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        # Apply the gradient update to the parameters.
        optimizer.step()

        batch_size = target.size(0)
        # Undo the per-batch mean so the epoch loss is a true per-sample mean.
        total_loss += loss.item() * batch_size
        prediction = torch.argmax(output, 1)
        correct += (prediction == target).sum().item()
        sample_num += batch_size

    if sample_num == 0:
        raise ValueError("data_loader yielded no samples")
    return total_loss / sample_num, correct / sample_num

# 测试函数
def test_model(model, data_loader, criterion, device):
    model.eval()
    test_batch_num = len(data_loader)
    total_loss = 0
    correct = 0
    sample_num = 0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            data, target = data.to(device).float(), target.to(device).long()
            output = model(data)
            loss = criterion(output, target)

            total_loss += loss.item()

            prediction = torch.argmax(output, 1)
            correct += (prediction == target).sum().item()
            sample_num += len(prediction)

    loss = total_loss / test_batch_num
    acc = correct / sample_num
    return loss, acc

# Full training loop: train and evaluate once per epoch.
def train(model, train_loader, test_loader, criterion, optimizer, epochs, device):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``test_loader`` after each.

    Prints one progress line per epoch and the total elapsed time at the end.

    Returns:
        ``(train_losses, train_acc_list, test_losses, test_acc_list)``, each a
        list with one value per epoch.
    """
    train_losses = []
    train_acc_list = []
    test_losses = []
    test_acc_list = []
    # perf_counter is a monotonic clock intended for measuring durations.
    start = time.perf_counter()

    for epoch in range(epochs):
        # One optimisation pass over the training set.
        train_loss, train_acc = train_model(model, train_loader, criterion, optimizer, device=device)

        # Evaluate on the test set (no parameter updates).
        test_loss, test_acc = test_model(model, test_loader, criterion, device=device)

        train_losses.append(train_loss)
        train_acc_list.append(train_acc)
        test_losses.append(test_loss)
        test_acc_list.append(test_acc)

        print(f'Epoch: {epoch + 1}/{epochs},\t train_loss: {train_loss:.4f},\t train_acc: {train_acc:.4f},\t test_loss: {test_loss:.4f},\t test_acc: {test_acc:.4f}')

    elapsed = time.perf_counter() - start
    print(f'\n训练完毕，耗时：{elapsed:.2f}s')
    return train_losses, train_acc_list, test_losses, test_acc_list


### 模型定性分析函数

In [None]:
# Qualitative analysis: show inputs, model outputs and targets side by side.
def qualitative_Analysis(model, datasets, mdoelPath):
    """Load a checkpoint into ``model`` and plot input / output / target for two samples.

    Args:
        model: Module whose ``state_dict`` matches ``checkpoint['model_state']``.
        datasets: Indexable of ``(input, target)`` pairs; the first two are shown.
            Both elements are assumed to be image tensors that ``ToPILImage``
            accepts (e.g. hazy / clear image pairs) -- TODO confirm per dataset.
        mdoelPath: Path to a checkpoint saved as ``{'model_state': ..., ...}``.
            The name is kept unchanged for keyword-argument compatibility.

    Uses the module-level ``device`` for inference.
    """
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(mdoelPath, map_location=device)
    model.load_state_dict(checkpoint['model_state'])
    model.eval()

    num_rows = 2
    with torch.no_grad():
        outputs = [
            model(datasets[i][0].unsqueeze(0).to(device)).squeeze()
            for i in range(num_rows)
        ]

    to_pil = transforms.ToPILImage()
    fig, axes = plt.subplots(num_rows, 3)
    fig.subplots_adjust(wspace=0.1, hspace=0)
    fig.suptitle("带雾图——模型输出图——去雾图", fontsize=10)
    for i in range(num_rows):
        # Clamp to [0, 1] and move to CPU: ToPILImage scales float tensors by 255
        # before casting to uint8, so out-of-range outputs would overflow.
        panels = (datasets[i][0], outputs[i].clamp(0, 1).cpu(), datasets[i][1])
        for col, img in enumerate(panels):
            axes[i][col].imshow(to_pil(img))
            axes[i][col].axis('off')

    plt.show()

## 实验内容

In [None]:
# Model definition (template: fill in the layers and the forward pass).
class ModuleName(nn.Module):
    """Placeholder model; replace with the experiment's architecture."""

    def __init__(self):
        super().__init__()
        # TODO: define layers here.

    def forward(self, X):
        # TODO: compute and return the model output for the input batch ``X``.
        # Raising keeps the cell syntactically valid until the forward pass is written.
        raise NotImplementedError("ModuleName.forward is a template; implement it before training")

In [None]:
# Smoke test: push one training sample through a fresh model and inspect the output shape.
model = ModuleName().to(device)
# Inference mode for a shape check: no autograd graph, and layers such as
# BatchNorm may fail on a single-sample batch in training mode.
model.eval()
with torch.no_grad():
    # unsqueeze (not the in-place unsqueeze_) adds the batch dimension without
    # mutating the tensor returned by the dataset.
    sample = train_dataset[1][0].unsqueeze(0).to(device)
    output = model(sample).squeeze()
output.shape

In [None]:
# Hyper-parameters.
epochs = 20
lr = 0.001

# Fresh model, loss and optimizer for the training run.
model = ModuleName().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Train and collect the per-epoch metrics.
(train_losses,
 train_acc_list,
 test_losses,
 test_acc_list) = train(model, train_loader, test_loader, criterion, optimizer, epochs, device)

# Save the weights together with the training history.
model_state = model.state_dict()
other_info = dict(
    epoch=epochs,
    train_losses=train_losses,
    train_acc_list=train_acc_list,
    test_losses=test_losses,
    test_acc_list=test_acc_list,
)
torch.save(dict(model_state=model_state, other_info=other_info), 'modelPath.pth')

In [None]:
# Loss curves over the training epochs.
plot_loss(train_losses, test_losses, "title")

In [None]:
# Accuracy curves over the training epochs.
plot_acc(train_acc_list, test_acc_list, "title")