In [1]:
import h5py
import numpy as np
import megengine.data as data
import megengine.data.transform as T
# import megengine
import cv2
import os
from megengine.data.dataset import Dataset
import random

# import megengine as mge
# import megengine.module as M
# import megengine.functional as F

# from megengine.data.transform import ToMode
# from megengine.data import DataLoader, RandomSampler

import torch
import torchvision
import torch.utils.data as Data
# import torch.nn as M
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib import cm

from tqdm.notebook import tqdm
from torch.autograd import Variable
import wandb

In [12]:
def psnr(img1, img2):
    mse = F.mse_loss(img1, img2)
#     mse = torch.mean((img1 - img2) ** 2)
    return 20 * torch.log10(1.0 / torch.sqrt(mse)).item()

In [2]:
class cfg:
    folder = "/root/autodl-tmp/SIDD_Small_Raw_Only/"
    bs = 16 # batch size
    cut_size = 256
    device = 'cuda:0'
    load_pretrain = False
    lr = 1e-3
    n_steps = 10
    gamma = 0.7
    start_epoch = 0
    n_epochs = 10000
    model_name = "our_model"

In [3]:
def toTensor(img):
    img = torch.from_numpy(img)
    return img.float()

In [4]:
class SIDDDataset(Dataset):
    def __init__(self):
        self.image_folder = cfg.folder + "Data/"
        self.image_list = os.listdir(self.image_folder)
        self.cut = cfg.cut_size

    # get the sample
    def __getitem__(self, idx):
        # get the index
        image_id = self.image_list[idx]
        inoisy = h5py.File(self.image_folder + self.image_list[idx] + "/NOISY_RAW_010.MAT")
        igt = h5py.File(self.image_folder + self.image_list[idx] + "/NOISY_RAW_010.MAT")
        
        inoisy = np.float32(np.array(inoisy['x']).T)
        igt = np.float32(np.array(igt['x']).T)
        w_origin, h_origin = igt.shape
        
        x = random.randint(0, w_origin - self.cut - 1)
        y = random.randint(0, h_origin - self.cut - 1)
        
        inoisy = inoisy[np.newaxis, x:x+self.cut, y:y+self.cut]
        igt = igt[np.newaxis, x:x+self.cut, y:y+self.cut]
        
        return toTensor(inoisy), toTensor(igt)

    def __len__(self):
        return len(self.image_list)

In [5]:
from torch.utils.data import random_split

In [6]:
dataset = SIDDDataset()
train_set, vali_set = random_split(dataset=dataset, lengths=[148, 12])

train_dataloader = Data.DataLoader(train_set,
      shuffle=True, batch_size=cfg.bs
)
vali_dataloader = Data.DataLoader(vali_set,
      shuffle=False, batch_size=1
)

In [7]:
class Predictor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(4, 50, 3, padding = 1, bias = True),
            nn.LeakyReLU(negative_slope = 0.125),
            nn.Conv2d(50, 50, 3, padding = 1, bias = True),
            nn.LeakyReLU(negative_slope = 0.125),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(50, 50, 3, padding = 1, bias = True),
            nn.LeakyReLU(negative_slope = 0.125),
            nn.Conv2d(50, 50, 3, padding = 1, bias = True),
            nn.LeakyReLU(negative_slope = 0.125),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(50, 50, 3, padding = 1, bias = True),
            nn.LeakyReLU(negative_slope = 0.125),
            nn.Conv2d(50, 4, 3, padding = 1, bias = True),
            nn.LeakyReLU(negative_slope = 0.125),
        )
    def forward(self, x):
        n, c, h, w = x.shape
        x = x.reshape((n, c, h // 2, 2, w // 2, 2)).permute(0, 1, 3, 5, 2, 4).reshape((n, c * 4, h // 2, w // 2))
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.reshape((n, c, 2, 2, h // 2, w // 2)).permute(0, 1, 4, 2, 5, 3).reshape((n, c, h, w))
        return x

In [10]:
def train(train_loader, vali_loader):
    model = Predictor().to(cfg.device)
    
    # 加载预训练权重
    if cfg.load_pretrain:
        # TODO
        s = 0
        
    
    criterion_L1 = torch.nn.L1Loss().to(cfg.device)
    optimizer = torch.optim.Adam([paras for paras in model.parameters() if paras.requires_grad == True], lr=cfg.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.n_steps, gamma=cfg.gamma)

    
    loss_list = []
    
    # 开始训练
    for idx_epoch in range(cfg.start_epoch, cfg.n_epochs):
        # record loss
        loss_epoch = []
        wandb.log({"epoch": idx_epoch})
        for idx_iter, (inoisy, igt) in tqdm(enumerate(train_loader)):
            inoisy, igt = Variable(inoisy).to(cfg.device), Variable(igt).to(cfg.device)
            ipred = model(inoisy)
            loss = criterion_L1(ipred, igt)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            loss_epoch.append(loss.data.cpu())
            
        scheduler.step()
        lr_now = scheduler.get_last_lr()
        wandb.log({"lr": float(lr_now[-1])})
        
        # 输出结果
        if idx_epoch % 1 == 0:
            print('Epoch--%4d, loss--%f' %
                  (idx_epoch + 1, float(np.array(loss_epoch).mean())))
            wandb.log({"loss": float(np.array(loss_epoch).mean())})
            
        # 保存模型
        if idx_epoch % 5 == 0:
            torch.save({'epoch': idx_epoch + 1, 'state_dict': model.state_dict()},
                       'log/' + cfg.model_name + '_' + 'epoch' + str(idx_epoch + 1) + '.pth.tar')
            
        # validation
        if idx_epoch % 1 == 0:
            psnr_list = []
            for idx_iter, (inoisy, igt) in tqdm(enumerate(vali_loader)):
                inoisy, igt = Variable(inoisy).to(cfg.device), Variable(igt).to(cfg.device)
                with torch.no_grad():
                    ipred = model(inoisy)
                    
                p = psnr(inoisy, ipred)
                psnr_list.append(p)
            print("Tested PSNR", str(sum(psnr_list) / 12))
            wandb.log({"PSNR": float(sum(psnr_list) / 12)})

In [None]:
train_cfg = {
                "batch size": cfg.bs,
                "cut_size":cfg.cut_size,
                "num_epoch":cfg.n_epochs,
                "init_lr":cfg.lr,
                "n_steps":cfg.n_steps,
                "change_gamma":cfg.gamma
}

wandb.init(project="blind-denoising",
          config = train_cfg)
train(train_dataloader, vali_dataloader)




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁
loss,▁
lr,▁

0,1
epoch,0.0
loss,0.05628
lr,0.001


0it [00:00, ?it/s]

In [None]:
class Autoencoder(M.Module):
    def __init__(self):
        super().__init__()
        self.encoder = M.Sequential([
            # Input(shape=(28, 28, 1,)),
            M.Conv2D(4, 50, 3, padding = 1, bias = True),
            M.MaxPooling2D((2, 2), padding='same'),
            M.Conv2D(32, (3, 3), padding='same', activation='relu',dilation_rate=(3,3)),
            M.MaxPooling2D((1, 1), padding='same'),
            M.Conv2D(32, (3, 3), padding='same', activation='relu',dilation_rate=(4,4)),
            M.MaxPooling2D((1, 1), padding='same'),
            M.Conv2D(32, (3, 3), padding='same', activation='relu',dilation_rate=(4,4)),
            M.MaxPooling2D((2, 2), padding='same'),
        ])
        self.decoder=M.Sequential([
            M.Conv2D(32, (3, 3), padding='same', activation='relu',dilation_rate=(2,2)),
            M.UpSampling2D((2, 2)),
            M.Conv2D(32, (3, 3), padding='same', activation='relu',dilation_rate=(3,3)),
            M.UpSampling2D((2, 2)),
            M.Conv2D(32, (3, 3), padding='same', activation='relu',dilation_rate=(4,4)),
            M.UpSampling2D((1, 1)),
            M.Conv2D(32, (3, 3), padding='same', activation='relu',dilation_rate=(4,4)),
            M.UpSampling2D((1, 1)),
            M.Conv2D(1, (3, 3), padding='same', activation='sigmoid')
        ])
    def forward(self, inputs):
        e=self.encoder(inputs)
        y=self.decoder(e)
        return encoder,decoder

In [None]:
# 检查参数量
model = Predictor()
# print(model)
autoencoder=Autoencoder()

In [None]:
# from megengine.utils.module_stats import module_stats

# input_data = np.random.rand(1, 1, 256, 256).astype("float32")
# total_stats, stats_details = module_stats(
#     net,
#     inputs = (input_data,),
#     cal_params = True,
#     cal_flops = True,
#     logging_to_stdout = True,
# )

# print("params %.3fK MAC/pixel %.0f"%(total_stats.param_dims/1e3, total_stats.flops/input_data.shape[2]/input_data.shape[3]))