In [1]:
import numpy as np
import torch


class AverageMeter(object):
    """Track the latest value and the sample-weighted running average of a metric.

    Examples::
        >>> # Initialize a meter to record loss
        >>> losses = AverageMeter()
        >>> # Update meter after every minibatch update
        >>> losses.update(loss_value, batch_size)
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # value passed to the most recent update()
        self.avg = 0    # sum / count; stays 0 until a sample with non-zero weight arrives
        self.sum = 0    # weighted sum of all recorded values (val * n)
        self.count = 0  # total weight (number of samples) recorded

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` samples and refresh the running average.

        Args:
            val: the new value, e.g. a mini-batch mean loss.
            n: weight of ``val``, typically the batch size. With a total weight of 0
                the average is left unchanged instead of raising ZeroDivisionError.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count:
            self.avg = self.sum / self.count



def convert_rgb_to_ycbcr(img):
    """Convert an RGB image with 0-255 values to YCbCr (ITU-R BT.601, 8-bit studio swing).

    Args:
        img: a numpy array whose last axis holds R, G, B (e.g. shape (H, W, 3)), or a
            torch tensor of shape (3, H, W) or (1, 3, H, W). For 4-D tensors only a
            leading batch dimension of size 1 is removed.

    Returns:
        Y, Cb, Cr stacked on a new last axis, as the same library type as ``img``.
        Note that the tensor path also returns channels-last, i.e. (H, W, 3).

    Raises:
        TypeError: if ``img`` is neither a numpy array nor a torch tensor.
    """
    if isinstance(img, np.ndarray):
        r, g, b = img[..., 0], img[..., 1], img[..., 2]
    elif isinstance(img, torch.Tensor):
        if img.dim() == 4:
            img = img.squeeze(0)
        r, g, b = img[0, :, :], img[1, :, :], img[2, :, :]
    else:
        raise TypeError('Unknown Type', type(img))

    # BT.601 digital coefficients (the analog ones scaled by 256/255); 65.738 = 65.481 * 256 / 255.
    y = 16. + (65.738 * r + 129.057 * g + 25.064 * b) / 256.
    cb = 128. + (-37.945 * r - 74.494 * g + 112.439 * b) / 256.
    cr = 128. + (112.439 * r - 94.154 * g - 18.285 * b) / 256.

    if isinstance(img, np.ndarray):
        return np.stack([y, cb, cr], axis=-1)
    # stack (not cat) creates the new channel axis; the 2-D planes cannot be permuted after cat.
    return torch.stack([y, cb, cr], dim=-1)

def convert_ycbcr_to_rgb(img):
    """Convert a YCbCr image (ITU-R BT.601, 8-bit studio swing) back to RGB on a 0-255 scale.

    Results are not clipped; callers should clip to [0, 255] before casting to uint8.

    Args:
        img: a numpy array whose last axis holds Y, Cb, Cr (e.g. shape (H, W, 3)), or a
            torch tensor of shape (3, H, W) or (1, 3, H, W). For 4-D tensors only a
            leading batch dimension of size 1 is removed.

    Returns:
        R, G, B stacked on a new last axis, as the same library type as ``img``.
        Note that the tensor path also returns channels-last, i.e. (H, W, 3).

    Raises:
        TypeError: if ``img`` is neither a numpy array nor a torch tensor.
    """
    if isinstance(img, np.ndarray):
        y, cb, cr = img[..., 0], img[..., 1], img[..., 2]
    elif isinstance(img, torch.Tensor):
        if img.dim() == 4:
            img = img.squeeze(0)
        y, cb, cr = img[0, :, :], img[1, :, :], img[2, :, :]
    else:
        raise TypeError('Unknown Type', type(img))

    r = 298.082 * y / 256. + 408.583 * cr / 256. - 222.921
    g = 298.082 * y / 256. - 100.291 * cb / 256. - 208.120 * cr / 256. + 135.576
    b = 298.082 * y / 256. + 516.412 * cb / 256. - 276.836

    if isinstance(img, np.ndarray):
        return np.stack([r, g, b], axis=-1)
    # stack (not cat) creates the new channel axis; the 2-D planes cannot be permuted after cat.
    return torch.stack([r, g, b], dim=-1)

In [2]:
import glob

import h5py
import PIL.Image as pImg
import numpy as np
from torch.utils.data import Dataset


def rgb2gray(img):
    """Return the BT.601 luma (Y) channel of an RGB image with 0-255 values.

    Despite the name this is not a plain grey-scale average: it is the Y component of the
    8-bit studio-swing YCbCr conversion, so outputs lie roughly in [16, 235].

    Args:
        img: array whose last axis holds R, G, B, e.g. shape (H, W, 3).

    Returns:
        Array with the same leading shape as ``img`` and the channel axis removed.
    """
    # 65.738 = 65.481 * 256 / 255 is the BT.601 red weight for 0-255 inputs.
    return 16. + (65.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.



def setTrianData(imgPath, h5Path, scale=3, pSize=33, pStride=14):
    """Build an HDF5 file of aligned low-/high-resolution luma patches.

    Every file matched by ``{imgPath}/*`` (in sorted order) is opened as RGB and
    bicubic-resized so that its width and height are multiples of ``scale``; that is the
    high-resolution (HR) image. The low-resolution (LR) input is the HR image
    bicubic-downscaled by ``scale`` and bicubic-upscaled back to the HR size. Both are
    reduced to their Y channel with ``rgb2gray`` and cut into ``pSize`` x ``pSize``
    patches whose top-left corners lie on a grid with step ``pStride``.

    Args:
        imgPath: directory containing the source images.
        h5Path: output HDF5 path (overwritten). Datasets 'lr' and 'hr' are written, each
            stacking the patches along a new first axis; values are on the 0-255 scale
            (not normalised).
        scale: down-/up-scaling factor to simulate.
        pSize: patch side length in pixels.
        pStride: step between neighbouring patch origins.
    """
    lrPatches, hrPatches = [], []  # low- and high-resolution patches, in matching order
    for path in sorted(glob.glob(f'{imgPath}/*')):
        # Close the source file as soon as its pixels have been converted.
        with pImg.open(path) as src:
            hr = src.convert('RGB')
        lrWidth, lrHeight = hr.width // scale, hr.height // scale
        # Training size: the largest width/height divisible by scale.
        width, height = lrWidth * scale, lrHeight * scale
        hr = hr.resize((width, height), resample=pImg.BICUBIC)
        lr = hr.resize((lrWidth, lrHeight), resample=pImg.BICUBIC)
        lr = lr.resize((width, height), resample=pImg.BICUBIC)
        hr = rgb2gray(np.array(hr).astype(np.float32))
        lr = rgb2gray(np.array(lr).astype(np.float32))
        # Cut both images into patches on the same grid.
        for i in range(0, height - pSize + 1, pStride):
            for j in range(0, width - pSize + 1, pStride):
                lrPatches.append(lr[i:i + pSize, j:j + pSize])
                hrPatches.append(hr[i:i + pSize, j:j + pSize])
    # Open the output only once all patches exist; the context manager closes it even on error.
    with h5py.File(h5Path, 'w') as h5_file:
        h5_file.create_dataset('lr', data=np.array(lrPatches))
        h5_file.create_dataset('hr', data=np.array(hrPatches))


class DataSet(Dataset):
    """Map-style dataset over the 'lr' / 'hr' patch arrays stored in an HDF5 file.

    Each item is a ``(lr, hr)`` pair of numpy arrays: the stored patch divided by 255
    with a leading channel axis of size 1 added.
    """

    def __init__(self, h5_file):
        # Zero-argument super() resolves past DataSet to torch's Dataset as intended.
        super().__init__()
        self.h5_file = h5_file  # path to the HDF5 file; it is reopened on every access

    def __getitem__(self, idx):
        """Return the ``idx``-th (lr, hr) patch pair, scaled by 1/255."""
        # Only the path is stored on the instance, never an open h5py handle.
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)

    def __len__(self):
        """Number of patches in the 'lr' dataset."""
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])


# Build the HDF5 patch datasets (uncomment and run once):
# trainImagePath = 'T91'
# testImagePath = 'Set5'
# trainSavePath = 'T91.h5'
# testSavePath = 'Set5.h5'

# setTrianData(trainImagePath, trainSavePath)
# setTrianData(testImagePath, testSavePath)



In [None]:
import os
import copy
import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from data_process import DataSet
import utils

# models.py
class SRCNN(nn.Module):
    """Three-layer SRCNN-style network: feature extraction, non-linear mapping, reconstruction.

    Each convolution is zero-padded by ``kernel_size // 2``, so for odd kernel sizes the
    output has the same spatial size as the input. The defaults build the 9-1-5 model
    with 64 and 32 feature maps; they also keep the state_dict keys
    (conv1/conv2/conv3 weight and bias) unchanged.

    Args:
        nChannel: number of input and output channels (1 for a single luma channel).
        n1: feature maps produced by the first layer.
        n2: feature maps produced by the second layer.
        f1, f2, f3: kernel sizes of the three convolutions.
    """

    def __init__(self, nChannel=1, n1=64, n2=32, f1=9, f2=1, f3=5):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(nChannel, n1, kernel_size=f1, padding=f1 // 2)
        self.conv2 = nn.Conv2d(n1, n2, kernel_size=f2, padding=f2 // 2)
        self.conv3 = nn.Conv2d(n2, nChannel, kernel_size=f3, padding=f3 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map an (N, nChannel, H, W) batch to an (N, nChannel, H, W) batch (odd kernels)."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return self.conv3(x)  # linear reconstruction layer, no activation


cudnn.benchmark = True  # let cuDNN pick the fastest conv algorithms (only affects CUDA runs)
# Training device. Forced to CPU; to use a GPU when one is available, switch to:
# device = torch.device(
#   'cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
outPath = "outputs"  # directory for per-epoch and best checkpoints
scale = 3            # upscaling factor; unused below, presumably must match the .h5 files
bSize = 16           # training batch size
nEpoch = 400         # number of training epochs
nWorker = 8          # DataLoader worker processes for the training set
seed = 42            # random seed

# Model, loss and optimiser
lr = 1e-4  # base learning rate
torch.manual_seed(seed)
model = SRCNN().to(device)
criterion = nn.MSELoss()
# The reconstruction layer (conv3) is trained with a 10x smaller learning rate.
optimizer = optim.Adam([
    {'params': model.conv1.parameters()},
    {'params': model.conv2.parameters()},
    {'params': model.conv3.parameters(), 'lr': lr * 0.1}
], lr=lr)

trainFile = "T91.h5"
evalFile = "Set5.h5"

# Training data
trainData = DataSet(trainFile)
trainLoader = DataLoader(dataset=trainData,
                         batch_size=bSize,
                         shuffle=True,          # reshuffle samples every epoch
                         num_workers=nWorker)   # previously declared but never used

# Evaluation data
evalDatas = DataSet(evalFile)
evalLoader = DataLoader(dataset=evalDatas, batch_size=1)


def initPSNR():
    """Return a fresh running-average record with no samples recorded yet."""
    return dict(avg=0, sum=0, count=0)


def updatePSNR(psnr, val, n=1):
    """Fold ``n`` samples of value ``val`` into a running-average record.

    Args:
        psnr: dict with keys 'avg', 'sum' and 'count' (it is not modified).
        val: mean value of the new samples.
        n: number of new samples.

    Returns:
        A new dict with the updated 'avg', 'sum' and 'count'.
    """
    total = psnr['sum'] + val * n
    count = psnr['count'] + n
    return dict(avg=total / count, sum=total, count=count)


# torch.save does not create missing directories, so ensure the checkpoint folder exists.
os.makedirs(outPath, exist_ok=True)

bestWeights = copy.deepcopy(model.state_dict())  # weights of the best epoch so far
bestEpoch = 0    # epoch that produced bestWeights
bestPSNR = 0.0   # best mean evaluation PSNR (dB), kept as a Python float

# Main training loop
for epoch in range(nEpoch):
    print('times:'+str(epoch))
    model.train()
    epochLosses = initPSNR()  # running mean of the training loss

    # Train
    for inputs, labels in trainLoader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        preds = model(inputs)
        loss = criterion(preds, labels)
        epochLosses = updatePSNR(epochLosses, loss.item(), len(inputs))
        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()        # backpropagate
        optimizer.step()       # update parameters from the gradients

    print(str(epochLosses['avg']))
    torch.save(model.state_dict(), os.path.join(outPath, f'epoch_{epoch}.pth'))

    # Evaluate
    model.eval()
    psnr = utils.AverageMeter()
    for inputs, labels in evalLoader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # No autograd graph is needed; predictions are clamped to the valid [0, 1] range.
        with torch.no_grad():
            preds = model(inputs).clamp(0.0, 1.0)
            mse = torch.mean((preds - labels) ** 2)
        # .item() keeps the meter (and hence bestPSNR) a Python float, so the checkpoint
        # name below reads "best_<epoch>_<psnr>.pth" instead of embedding "tensor(...)".
        psnr.update((10. * torch.log10(1. / mse)).item(), len(inputs))

    print(f'eval psnr: {psnr.avg:.2f}')

    if psnr.avg > bestPSNR:
        bestEpoch = epoch
        bestPSNR = psnr.avg
        bestWeights = copy.deepcopy(model.state_dict())

print(f'best epoch: {bestEpoch}, psnr: {bestPSNR:.2f}')
torch.save(bestWeights, os.path.join(outPath, f'best_{bestEpoch}_{bestPSNR:.2f}.pth'))


# test

import argparse

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

import PIL.Image as pil_image
from data_process import DataSet
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb


class SRCNN(nn.Module):
    """Three-layer SRCNN-style network: feature extraction, non-linear mapping, reconstruction.

    Each convolution is zero-padded by ``kernel_size // 2``, so for odd kernel sizes the
    output has the same spatial size as the input. The defaults build the 9-1-5 model
    with 64 and 32 feature maps; they also keep the state_dict keys
    (conv1/conv2/conv3 weight and bias) unchanged.

    Args:
        nChannel: number of input and output channels (1 for a single luma channel).
        n1: feature maps produced by the first layer.
        n2: feature maps produced by the second layer.
        f1, f2, f3: kernel sizes of the three convolutions.
    """

    def __init__(self, nChannel=1, n1=64, n2=32, f1=9, f2=1, f3=5):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(nChannel, n1, kernel_size=f1, padding=f1 // 2)
        self.conv2 = nn.Conv2d(n1, n2, kernel_size=f2, padding=f2 // 2)
        self.conv3 = nn.Conv2d(n2, nChannel, kernel_size=f3, padding=f3 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map an (N, nChannel, H, W) batch to an (N, nChannel, H, W) batch (odd kernels)."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return self.conv3(x)  # linear reconstruction layer, no activation

import os  # os.path.splitext derives the output file names

parser = argparse.ArgumentParser()
parser.add_argument('--image-file', default='Set5/woman.png', type=str)
parser.add_argument('--weights-file', default='outputs/best_83_36.28.pth', type=str)
parser.add_argument('--scale', type=int, default=3)
# parse_known_args() ignores unrecognised options such as the "-f <connection file>"
# that a Jupyter kernel puts in sys.argv; parse_args() would exit with an error there.
args, _ = parser.parse_known_args()

device = torch.device('cpu')
model = SRCNN().to(device)
# map_location allows weights saved from a GPU run to be loaded on this device.
model.load_state_dict(torch.load(args.weights_file, map_location=device))
model.eval()

scale = args.scale
# Split off only the final extension so directories containing dots are left intact.
stem, ext = os.path.splitext(args.image_file)

image = pil_image.open(args.image_file).convert('RGB')
# Resize to dimensions divisible by the scale, then simulate the degraded input:
# bicubic downscale by `scale` followed by bicubic upscale back to that size.
image_width = (image.width // scale) * scale
image_height = (image.height // scale) * scale
image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
image = image.resize((image.width // scale, image.height // scale), resample=pil_image.BICUBIC)
image = image.resize((image.width * scale, image.height * scale), resample=pil_image.BICUBIC)
image.save(f'{stem}_bicubic_x{scale}{ext}')

# To a float array, then to YCbCr; only the Y channel goes through the network.
image = np.array(image).astype(np.float32)
ycbcr = convert_rgb_to_ycbcr(image)
# Dividing creates a new array; an in-place `/=` on this view would also rescale ycbcr[..., 0].
y = ycbcr[..., 0] / 255.
# from_numpy shares memory with `y`; add batch and channel axes: (H, W) -> (1, 1, H, W).
y = torch.from_numpy(y).to(device).unsqueeze(0).unsqueeze(0)

# No autograd graph is needed; clamp predictions to the valid [0, 1] range.
with torch.no_grad():
    preds = model(y).clamp(0.0, 1.0)

# Back to an (H, W) numpy array of Y values on the 0-255 scale.
preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

# Recombine the network's Y with the bicubic Cb/Cr: (3, H, W) -> (H, W, 3).
output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])

# Back to RGB, clipped to [0, 255] and cast to uint8 for saving.
output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
output = pil_image.fromarray(output)
output.save(f'{stem}_srcnn_x{scale}{ext}')