# 训练自己的超分模型

In [None]:
import os

# Silence TensorFlow's native INFO/WARNING log output (e.g. the AVX notice).
# The variable has to be in the environment before TensorFlow is imported;
# setting it afterwards may not take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import cv2  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402




# 定义SRCNN模型

In [None]:

def srcnn_model():
    """Build the three-convolution SRCNN network as a Keras functional model.

    The input accepts 3-channel images of any height and width. Every
    convolution uses ``padding='same'``, and the final layer has no
    activation and produces 3 output channels.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    layers = tf.keras.layers

    image_in = layers.Input(shape=(None, None, 3))
    # 9x9 convolution with 64 filters.
    features = layers.Conv2D(
        filters=64, kernel_size=(9, 9), activation='relu', padding='same'
    )(image_in)
    # 1x1 convolution reducing to 32 filters.
    mapped = layers.Conv2D(
        filters=32, kernel_size=(1, 1), activation='relu', padding='same'
    )(features)
    # 5x5 convolution back to 3 channels, linear output.
    restored = layers.Conv2D(
        filters=3, kernel_size=(5, 5), padding='same'
    )(mapped)

    return tf.keras.Model(inputs=image_in, outputs=restored)


# Compile the SRCNN model
def compile_srcnn_model(model):
    """Configure *model* for training with the Adam optimizer and MSE loss.

    Args:
        model: Object exposing a Keras-style ``compile`` method; it is
            configured in place.
    """
    compile_options = {
        'optimizer': 'adam',
        'loss': 'mean_squared_error',
    }
    model.compile(**compile_options)


# Load the training dataset (the historical name mentions Set5)
def load_set5_dataset(dataset_path='dataset', size=(256, 256)):
    """Load every ``.png`` file in a directory as one stacked image array.

    Args:
        dataset_path: Directory containing the training images. Point this
            at your own data, e.g. DIV2K, Flickr2K, BSDS500, COCO or a
            custom set. Defaults to ``'dataset'``.
        size: ``(width, height)`` handed to ``cv2.resize`` so that all
            images share one shape and can be stacked.

    Returns:
        ``numpy.ndarray`` with one resized image per readable ``.png`` file,
        in sorted filename order. When no image could be loaded, an empty
        array of shape ``(0, height, width, 3)`` is returned.

    Raises:
        OSError: Propagated from ``os.listdir`` (for example
            ``FileNotFoundError`` when *dataset_path* does not exist).
    """
    images = []
    # Sort so the sample order does not depend on the directory listing order.
    for filename in sorted(os.listdir(dataset_path)):
        if not filename.endswith('.png'):
            continue
        image = cv2.imread(os.path.join(dataset_path, filename))
        if image is None:
            # cv2.imread reports an unreadable or corrupt file by returning
            # None; passing that on to cv2.resize would raise.
            continue
        images.append(cv2.resize(image, size))  # bring all images to one size

    if not images:
        # np.array([]) would be a 1-D float array; keep the image layout so
        # callers see a consistent rank even with no data.
        width, height = size
        return np.empty((0, height, width, 3), dtype=np.uint8)
    return np.array(images)


# Train the SRCNN model
def _degrade(image, scale):
    """Return *image* shrunk by *scale* and bicubically resized back to its size."""
    height, width = image.shape[:2]
    small_size = (max(1, int(width / scale)), max(1, int(height / scale)))
    small = cv2.resize(image, small_size, interpolation=cv2.INTER_CUBIC)
    return cv2.resize(small, (width, height), interpolation=cv2.INTER_CUBIC)


def train_srcnn_model(model, images, epochs=10, batch_size=16, scale=2):
    """Fit *model* to map degraded images back to the originals.

    Fitting ``images -> images`` only teaches the network an identity
    mapping. Each image is therefore downscaled by *scale* and resized back
    with bicubic interpolation to form the network input, while the
    untouched image is the training target.

    Args:
        model: Compiled model exposing a Keras-style ``fit`` method.
        images: Array of target images, indexed by sample on the first axis.
        epochs: Number of passes over the data.
        batch_size: Samples per gradient update.
        scale: Downscaling factor used to synthesize the low-quality inputs.
            A value of 1 or less disables the degradation and trains on
            ``images`` as both input and target.

    Returns:
        Whatever ``model.fit`` returns (a ``History`` object for Keras models).
    """
    targets = np.asarray(images)
    if scale > 1 and len(targets) > 0:
        inputs = np.stack([_degrade(image, scale) for image in targets])
    else:
        inputs = targets
    # NOTE(review): pixel values are passed through unscaled, exactly as
    # loaded; confirm that matches what inference code feeds the model.
    return model.fit(inputs, targets, epochs=epochs, batch_size=batch_size, verbose=1)


# Save the SRCNN model weights
def save_srcnn_weights(model, filename='srcnn.weights.h5'):
    """Write the weights of *model* to *filename*.

    Args:
        model: Object exposing a Keras-style ``save_weights`` method.
        filename: Destination path; defaults to ``'srcnn.weights.h5'``.
    """
    weights_path = filename
    model.save_weights(weights_path)



# 主函数

In [None]:


if __name__ == "__main__":
    # End-to-end run: load data, build and compile the model, train, save.
    # Load the dataset
    images = load_set5_dataset()

    # Create and compile the SRCNN model
    model = srcnn_model()
    compile_srcnn_model(model)

    # Train the SRCNN model
    train_srcnn_model(model, images)

    # Save the SRCNN model weights
    save_srcnn_weights(model)
