In [3]:
from lib.utils import *
import torch.backends.cudnn as cudnn
import torch
from torch import nn
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from lib.models import SRResNet
from lib.dataloaders import SRDataset
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import time
from PIL import Image

In [4]:
create_data_lists(train_folders=['./data/COCO2014/train2014',
                                     './data/COCO2014/val2014'],
                      test_folders=['./data/BSD100',
                                    './data/Set5',
                                    './data/Set14'],
                      min_size=100,
                      output_folder='./data/')


正在创建文件列表... 请耐心等待.

训练集中共有 123285 张图像

在测试集 BSD100 中共有 100 张图像

在测试集 Set5 中共有 5 张图像

在测试集 Set14 中共有 14 张图像

生成完毕。训练集和测试集文件列表已保存在 ./data/ 下



In [2]:
# 数据集参数
data_folder = './data/'          # 数据存放路径
crop_size = 96      # 高分辨率图像裁剪尺寸
scaling_factor = 4  # 放大比例

# 模型参数
large_kernel_size = 9   # 第一层卷积和最后一层卷积的核大小
small_kernel_size = 3   # 中间层卷积的核大小
n_channels = 64         # 中间层通道数
n_blocks = 16           # 残差模块数量

# 学习参数
checkpoint = None   # 预训练模型路径，如果不存在则为None
batch_size = 96    # 批大小 win11 4G显存下 batch_size=96  GPU_memory=3.7/4.0GB
start_epoch = 1     # 轮数起始位置
epochs = 20        # 迭代轮数
workers = 4         # 工作线程数
lr = 1e-4           # 学习率

# 设备参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ngpu = 1           # 用来运行的gpu数量

cudnn.benchmark = True # 对卷积进行加速

In [8]:
def train(writer,model_name):
    
    global checkpoint,start_epoch

    # 初始化
    model = SRResNet(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    # 初始化优化器
    optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),lr=lr)

    # 迁移至默认设备进行训练
    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # 加载预训练模型
    if checkpoint is not None:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    
    if torch.cuda.is_available() and ngpu > 1:
        model = nn.DataParallel(model, device_ids=list(range(ngpu)))

    # 定制化的dataloaders
    train_dataset = SRDataset(data_folder,split='train',
                              crop_size=crop_size,
                              scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm',
                              hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=workers,
        pin_memory=True) 

    # 开始逐轮训练
    for epoch in range(start_epoch, epochs+1):

        model.train() 

        loss_epoch = AverageMeter()  # 统计损失函数

        n_iter = len(train_loader)

        # 按批处理
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):

            # 数据移至默认设备进行训练
            lr_imgs = lr_imgs.to(device)  # (batch_size (N), 3, 24, 24), imagenet-normed 格式
            hr_imgs = hr_imgs.to(device)  # (batch_size (N), 3, 96, 96),  [-1, 1]格式

            # 前向传播
            sr_imgs = model(lr_imgs)

            # 计算损失
            loss = criterion(sr_imgs, hr_imgs)  

            # 后向传播
            optimizer.zero_grad()
            loss.backward()

            # 更新模型
            optimizer.step()

            # 记录损失值
            loss_epoch.update(loss.item(), lr_imgs.size(0))

            # 监控图像变化
            if i==(n_iter-2):
                writer.add_image('SRResNet/epoch_'+str(epoch)+'_1', make_grid(lr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('SRResNet/epoch_'+str(epoch)+'_2', make_grid(sr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)
                writer.add_image('SRResNet/epoch_'+str(epoch)+'_3', make_grid(hr_imgs[:4,:3,:,:].cpu(), nrow=4, normalize=True),epoch)

            # 打印结果
            print("第"+str(epoch)+"/"+str(epochs)+"个epoch,第"+str(i)+"/"+str(len(train_loader))+" 个batch训练结束    ",end='\r')
 
        # 手动释放内存              
        del lr_imgs, hr_imgs, sr_imgs

        # 监控损失值变化
        writer.add_scalar('SRResNet/MSE_Loss', loss_epoch.val, epoch)    

        # 保存预训练模型
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, 'results/'+model_name)
    
    # 训练结束关闭监控
    writer.close()

In [None]:
writer = SummaryWriter() # 实时监控     使用命令 tensorboard --logdir runs  进行查看
model_name='srresnet.pth'
train(writer,model_name)

In [1]:
# 双线性上采样
def Bicubic(imgDir,input,output):
        # 加载图像
    img = Image.open(imgDir+input, mode='r')
    img = img.convert('RGB')
 
    Bicubic_img = img.resize((int(img.width * scaling_factor),int(img.height * scaling_factor)),Image.BICUBIC)
    Bicubic_img.save(imgDir+output)

#超分处理
def super_revolution(model,imgDir,input,output):
    # 加载模型SRResNet 或 SRGAN
    checkpoint = torch.load(model)
    srresnet = SRResNet(large_kernel_size=large_kernel_size,
                        small_kernel_size=small_kernel_size,
                        n_channels=n_channels,
                        n_blocks=n_blocks,
                        scaling_factor=scaling_factor)
    srresnet = srresnet.to(device)
    srresnet.load_state_dict(checkpoint['model'])
   
    srresnet.eval()
    model = srresnet
 
    # 加载图像
    img = Image.open(imgDir+input, mode='r')
    img = img.convert('RGB')
 
    # 图像预处理
    lr_img = convert_image(img, source='pil', target='imagenet-norm')
    lr_img.unsqueeze_(0)
 
    # 记录时间
    start = time.time()
 
    # 转移数据至设备
    lr_img = lr_img.to(device)  # (1, 3, w, h ), imagenet-normed
 
    # 模型推理
    with torch.no_grad():
        sr_img = model(lr_img).squeeze(0).cpu().detach()  # (1, 3, w*scale, h*scale), in [-1, 1]   
        sr_img = convert_image(sr_img, source='[-1, 1]', target='pil')
        sr_img.save(imgDir+output)
 
    print(output+'用时  {:.3f} 秒'.format(time.time()-start))
 

In [2]:
model = "./results/srresnet.pth"
imgDir='./results/'
Bicubic(imgDir=imgDir,input='butterfly_GT.bmp',output='butterfly_x4_bicubic.jpg')
Bicubic(imgDir=imgDir,input='butterfly_x4_bicubic.jpg',output='butterfly_x16_bicubic.jpg')

NameError: name 'Image' is not defined

In [6]:
super_revolution(model,imgDir=imgDir,input='butterfly_GT.bmp',output='butterfly_x4_sr.jpg')

butterfly_x4_sr.jpg用时  2.702 秒


In [8]:
#显存不够用了
device = torch.device("cpu")
super_revolution(model,imgDir=imgDir,input='butterfly_x4_sr.jpg',output='butterfly_x16_sr.jpg')

butterfly_x16_sr.jpg用时  168.429 秒
