In [None]:
%run GlobalConfig.ipynb
%run CNNModel.ipynb
%run ImageSR.ipynb
# %run LoadDataSet.ipynb

import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from dataset import MyDataset  # 从文件中导入
from dataset import transform  # 从文件中导入
import matplotlib.pyplot as plt
import threading
import time
from torch.optim.lr_scheduler import ReduceLROnPlateau
    
def _to_image_batch(images):
    """Collapse every leading axis of *images* into a single batch axis.

    The last three axes are kept as they are, so the result is always 4-D.
    Unlike an argument-less ``squeeze()``, this never drops the batch axis
    when the batch holds a single sample, nor a size-1 channel axis.

    NOTE(review): assumes the loader yields tensors whose last three axes are
    (C, H, W), optionally preceded by extra singleton axes — confirm against
    ``MyDataset``.
    """
    return images.reshape(-1, *images.shape[-3:])


def train_model(num_epochs=5000):
    """Train ``ConvLutModel`` and checkpoint it to disk.

    Every epoch runs one pass over the training set followed by one pass over
    the validation set. The objective is ``MSE + 0.1 * perceptual + 0.1 * SSIM``.
    The weights with the lowest mean validation loss are written to
    ``best_model.pth``; the weights after the last epoch are written to
    ``final_model.pth``.

    Args:
        num_epochs: Number of epochs to train for (default 5000).

    Returns:
        A ``(train_losses, val_losses)`` tuple of per-epoch mean losses.

    Raises:
        ValueError: If the training or validation loader yields no batches.
    """
    # Training dataset and loader.
    my_dataset = MyDataset(root_dir=global_train_dataset_path, scale_factor=global_scale_factor,
                           crop_size=global_crop_size, transform=transform)
    my_loader = DataLoader(my_dataset, batch_size=global_batch_size, shuffle=True, num_workers=4)

    # Use the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Network and loss terms.
    model = ConvLutModel().to(device)
    criterion_mse = nn.MSELoss().to(device)
    vgg = VGGFeatureExtractor().to(device)
    # The optimizer below only receives model.parameters(), so the VGG network
    # is a fixed feature extractor: keep it in eval mode and skip its weight
    # gradients (gradients still flow through it to the model output).
    vgg.eval()
    vgg.requires_grad_(False)
    criterion_perceptual = lambda sr, hr: perceptual_loss(sr, hr, vgg)
    criterion_ssim = ssim_loss

    optimizer = optim.Adam(model.parameters(), lr=global_learn_rate, weight_decay=1e-5)
    # Halve the learning rate after 10 epochs without validation improvement.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

    # Validation dataset and loader. Shuffling brings nothing for evaluation,
    # so the loader iterates in a fixed order.
    val_dataset = MyDataset(root_dir=global_valid_dataset_path, scale_factor=global_scale_factor,
                            crop_size=global_crop_size, transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=global_batch_size, shuffle=False, num_workers=4)

    # The per-epoch means below divide by the number of batches.
    if len(my_loader) == 0 or len(val_loader) == 0:
        raise ValueError("training and validation datasets must each yield at least one batch")

    def combined_loss(lr_img, hr_img):
        """Return the weighted training objective for one (LR, HR) batch."""
        # apply_lut_batchs is fed channels-last input and its output is
        # permuted back to channels-first before the losses are evaluated.
        lr_img = _to_image_batch(lr_img).permute(0, 2, 3, 1).to(device)
        hr_img = _to_image_batch(hr_img).to(device)

        sr_img = apply_lut_batchs(lr_img, model, scale_factor=global_scale_factor, device=device)
        sr_img = sr_img.permute(0, 3, 1, 2)

        loss_mse = criterion_mse(sr_img, hr_img)
        loss_perceptual = criterion_perceptual(sr_img, hr_img)
        loss_ssim = criterion_ssim(sr_img, hr_img)
        return loss_mse + 0.1 * loss_perceptual + 0.1 * loss_ssim

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        epoch_loss = 0.0
        for lr_img, hr_img in my_loader:
            loss = combined_loss(lr_img, hr_img)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
        train_loss = epoch_loss / len(my_loader)

        # ---- validation pass ----
        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for lr_img, hr_img in val_loader:
                val_loss += combined_loss(lr_img, hr_img).item()
        val_loss /= len(val_loader)

        # Checkpoint whenever the validation loss reaches a new minimum.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        scheduler.step(val_loss)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    # Save the weights reached after the last epoch.
    torch.save(model.state_dict(), 'final_model.pth')

    return train_losses, val_losses

# Start training only when this cell/script is the entry point.
if "__main__" == __name__:
    train_model()


Using device: cuda
Epoch [1/5000], Loss: 0.0678, Val Loss: 0.0276
Epoch [2/5000], Loss: 0.0244, Val Loss: 0.0211
Epoch [3/5000], Loss: 0.0203, Val Loss: 0.0165
Epoch [4/5000], Loss: 0.0191, Val Loss: 0.0167
Epoch [5/5000], Loss: 0.0182, Val Loss: 0.0156
Epoch [6/5000], Loss: 0.0161, Val Loss: 0.0152
Epoch [7/5000], Loss: 0.0164, Val Loss: 0.0145
Epoch [8/5000], Loss: 0.0164, Val Loss: 0.0129
Epoch [9/5000], Loss: 0.0146, Val Loss: 0.0131
Epoch [10/5000], Loss: 0.0141, Val Loss: 0.0124
Epoch [11/5000], Loss: 0.0145, Val Loss: 0.0128
Epoch [12/5000], Loss: 0.0144, Val Loss: 0.0110
Epoch [13/5000], Loss: 0.0139, Val Loss: 0.0108
Epoch [14/5000], Loss: 0.0139, Val Loss: 0.0133
Epoch [15/5000], Loss: 0.0137, Val Loss: 0.0120
Epoch [16/5000], Loss: 0.0137, Val Loss: 0.0111
Epoch [17/5000], Loss: 0.0137, Val Loss: 0.0123
Epoch [18/5000], Loss: 0.0133, Val Loss: 0.0127
Epoch [19/5000], Loss: 0.0131, Val Loss: 0.0117
Epoch [20/5000], Loss: 0.0136, Val Loss: 0.0117
Epoch [21/5000], Loss: 0.0133,