Imports

In [None]:
import torch
import data as Data
import model as Model
import argparse
import logging
import core.logger as Logger
import core.metrics as Metrics
from core.wandb_logger import WandbLogger
from tensorboardX import SummaryWriter
import os
import numpy as np

from io import BytesIO
import lmdb
from PIL import Image
from torch.utils.data import Dataset
import random
import data.util as Util

Config

In [None]:
# Run options for this notebook, written out as one nested dictionary.
# Trailing comments record where a value came from (config file vs. CLI).
# NOTE(review): "config_file" names the 16->128 config while the resolutions
# below are 64/512 -- confirm this is intentional.
opt = dict(
    name="distributed_high_sr_ffhq",
    phase="train",  # overridden or defaulted from CLI
    gpu_ids=[0, 1],  # from config (unless overridden by CLI)
    debug=False,  # default from CLI
    enable_wandb=False,  # default from CLI
    log_wandb_ckpt=False,  # default from CLI
    log_eval=False,  # default from CLI
    # Output locations and optional checkpoint to resume from.
    path=dict(
        log="logs",
        tb_logger="tb_logger",
        results="results",
        checkpoint="checkpoint",
        resume_state=None,
    ),
    # One entry per phase; each entry is handed to the dataset factory.
    datasets=dict(
        train=dict(
            name="FFHQ",
            mode="HR",
            dataroot="dataset/ffhq_64_512",
            datatype="img",
            l_resolution=64,
            r_resolution=512,
            batch_size=2,
            num_workers=8,
            use_shuffle=True,
            data_len=-1,
        ),
        val=dict(
            name="CelebaHQ",
            mode="LRHR",
            dataroot="dataset/celebahq_64_512",
            datatype="img",
            l_resolution=64,
            r_resolution=512,
            data_len=50,
        ),
    ),
    # Network and diffusion settings.
    model=dict(
        which_model_G="sr3",
        finetune_norm=False,
        unet=dict(
            in_channel=6,
            out_channel=3,
            inner_channel=64,
            norm_groups=16,
            channel_multiplier=[1, 2, 4, 8, 16],
            attn_res=[],
            res_blocks=1,
            dropout=0,
        ),
        # Train and val schedules are separate dict objects on purpose so
        # that changing one never affects the other.
        beta_schedule=dict(
            train=dict(
                schedule="linear",
                n_timestep=2000,
                linear_start=1e-6,
                linear_end=1e-2,
            ),
            val=dict(
                schedule="linear",
                n_timestep=2000,
                linear_start=1e-6,
                linear_end=1e-2,
            ),
        ),
        diffusion=dict(
            image_size=512,
            channels=3,
            conditional=True,
        ),
    ),
    # Optimisation loop settings (frequencies are counted in iterations).
    train=dict(
        n_iter=1000000,
        val_freq=1e4,
        save_checkpoint_freq=1e4,
        print_freq=50,
        optimizer=dict(
            type="adam",
            lr=3e-6,
        ),
        ema_scheduler=dict(
            step_start_ema=5000,
            update_ema_every=1,
            ema_decay=0.9999,
        ),
    ),
    wandb=dict(
        project="distributed_high_sr_ffhq",
    ),
    config_file="config/sr_sr3_16_128.json",  # from CLI
)

Logger

In [None]:
# # logging
# torch.backends.cudnn.enabled = True
# torch.backends.cudnn.benchmark = True

# Logger.setup_logger(None, opt['path']['log'],
#                     'train', level=logging.INFO, screen=True)
# Logger.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
# logger = logging.getLogger('base')
# logger.info(Logger.dict2str(opt))
# tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])

# # Initialize WandbLogger
# if opt['enable_wandb']:
#     import wandb
#     wandb_logger = WandbLogger(opt)
#     wandb.define_metric('validation/val_step')
#     wandb.define_metric('epoch')
#     wandb.define_metric("validation/*", step_metric="val_step")
#     val_step = 0
# else:
#     wandb_logger = None

Dataset creation

In [None]:
class LRHRDataset(Dataset):
    """Dataset of HR/SR (and optionally LR) images read from LMDB or image folders.

    With ``datatype='lmdb'`` the images are looked up under the keys
    ``hr_<r_res>_<index>``, ``sr_<l_res>_<r_res>_<index>`` and
    ``lr_<l_res>_<index>`` (index zero-padded to five digits), and the number
    of stored entries is read from the key ``length``.

    With ``datatype='img'`` file lists are collected from
    ``<dataroot>/hr_<r_res>``, ``<dataroot>/sr_<l_res>_<r_res>`` and, when
    ``need_LR`` is set, ``<dataroot>/lr_<l_res>``.

    Args:
        dataroot: LMDB path or root directory of the image folders.
        datatype: ``'lmdb'`` or ``'img'``.
        l_resolution: Low resolution used in key / folder names.
        r_resolution: High resolution used in key / folder names.
        split: Passed through to ``Util.transform_augment``.
        data_len: Maximum number of samples to expose; ``<= 0`` means all.
        need_LR: Also load and return the low-resolution image.

    Raises:
        NotImplementedError: If ``datatype`` is neither ``'lmdb'`` nor ``'img'``.
    """

    def __init__(self, dataroot, datatype, l_resolution=16, r_resolution=128, split='train', data_len=-1, need_LR=False):
        self.datatype = datatype
        self.l_res = l_resolution
        self.r_res = r_resolution
        self.data_len = data_len
        self.need_LR = need_LR
        self.split = split

        if datatype == 'lmdb':
            self.env = lmdb.open(dataroot, readonly=True, lock=False,
                                 readahead=False, meminit=False)
            # The store records its own size under the "length" key.
            with self.env.begin(write=False) as txn:
                self.dataset_len = int(txn.get("length".encode("utf-8")))
        elif datatype == 'img':
            self.sr_path = Util.get_paths_from_images(
                '{}/sr_{}_{}'.format(dataroot, l_resolution, r_resolution))
            self.hr_path = Util.get_paths_from_images(
                '{}/hr_{}'.format(dataroot, r_resolution))
            if self.need_LR:
                self.lr_path = Util.get_paths_from_images(
                    '{}/lr_{}'.format(dataroot, l_resolution))
            # The HR list defines how many samples exist.
            self.dataset_len = len(self.hr_path)
        else:
            # Plain "{}" (not "{:s}") so a non-string datatype such as None
            # still yields this error rather than a formatting TypeError.
            raise NotImplementedError(
                'data_type [{}] is not recognized.'.format(datatype))

        # Non-positive data_len means "use everything"; otherwise never
        # expose more samples than are actually available.
        if self.data_len <= 0:
            self.data_len = self.dataset_len
        else:
            self.data_len = min(self.data_len, self.dataset_len)

    def __len__(self):
        """Return the number of samples exposed by this dataset."""
        return self.data_len

    def _read_lmdb_entry(self, txn, index):
        """Fetch the raw (hr, sr, lr) byte strings stored for ``index``.

        ``lr`` is ``None`` when ``need_LR`` is off; any element is ``None``
        when its key is missing from the store.
        """
        suffix = str(index).zfill(5)
        hr_img_bytes = txn.get(
            'hr_{}_{}'.format(self.r_res, suffix).encode('utf-8'))
        sr_img_bytes = txn.get(
            'sr_{}_{}_{}'.format(self.l_res, self.r_res, suffix).encode('utf-8'))
        lr_img_bytes = None
        if self.need_LR:
            lr_img_bytes = txn.get(
                'lr_{}_{}'.format(self.l_res, suffix).encode('utf-8'))
        return hr_img_bytes, sr_img_bytes, lr_img_bytes

    def _is_incomplete(self, entry):
        """Return True when a record required for this sample is missing."""
        hr_img_bytes, sr_img_bytes, lr_img_bytes = entry
        if hr_img_bytes is None or sr_img_bytes is None:
            return True
        # The LR record only matters when the caller asked for it.
        return self.need_LR and lr_img_bytes is None

    def __getitem__(self, index):
        """Load one sample and return it as a dict.

        Returns a dict with keys ``'HR'``, ``'SR'`` and ``'Index'``, plus
        ``'LR'`` when ``need_LR`` is set. The images are whatever
        ``Util.transform_augment`` returns for ``min_max=(-1, 1)``
        (presumably tensors scaled to that range -- confirm in data/util).
        ``'Index'`` is always the requested index, even when an LMDB sample
        had to be replaced by a randomly drawn one.
        """
        img_HR = None
        img_SR = None
        img_LR = None

        if self.datatype == 'lmdb':
            with self.env.begin(write=False) as txn:
                entry = self._read_lmdb_entry(txn, index)
                # Skip invalid indices: redraw until every required record
                # (including LR when requested) is present.
                while self._is_incomplete(entry):
                    new_index = random.randint(0, self.data_len - 1)
                    entry = self._read_lmdb_entry(txn, new_index)
                hr_img_bytes, sr_img_bytes, lr_img_bytes = entry
                img_HR = Image.open(BytesIO(hr_img_bytes)).convert("RGB")
                img_SR = Image.open(BytesIO(sr_img_bytes)).convert("RGB")
                if self.need_LR:
                    img_LR = Image.open(BytesIO(lr_img_bytes)).convert("RGB")
        else:
            img_HR = Image.open(self.hr_path[index]).convert("RGB")
            img_SR = Image.open(self.sr_path[index]).convert("RGB")
            if self.need_LR:
                img_LR = Image.open(self.lr_path[index]).convert("RGB")

        # All images of a sample go through the transform in one call.
        if self.need_LR:
            [img_LR, img_SR, img_HR] = Util.transform_augment(
                [img_LR, img_SR, img_HR], split=self.split, min_max=(-1, 1))
            return {'LR': img_LR, 'HR': img_HR, 'SR': img_SR, 'Index': index}
        else:
            [img_SR, img_HR] = Util.transform_augment(
                [img_SR, img_HR], split=self.split, min_max=(-1, 1))
            return {'HR': img_HR, 'SR': img_SR, 'Index': index}

In [None]:
# dataset
# Walk the configured phases and build a dataset + dataloader for each one
# that is needed: the validation pair is always built, the training pair only
# when the notebook is not running in pure 'val' mode. Other phases are ignored.
for phase, dataset_opt in opt['datasets'].items():
    if phase == 'val':
        val_set = Data.create_dataset(dataset_opt, phase)
        val_loader = Data.create_dataloader(val_set, dataset_opt, phase)
        continue
    if phase != 'train' or opt['phase'] == 'val':
        continue
    train_set = Data.create_dataset(dataset_opt, phase)
    train_loader = Data.create_dataloader(train_set, dataset_opt, phase)
# logger.info('Initial Dataset Finished')

Model loading

In [None]:
# model
# Build the model object from the full option dict. The training cell reads
# `begin_step` / `begin_epoch` from it and drives it through `feed_data`,
# `optimize_parameters` and `get_current_log`.
diffusion = Model.create_model(opt)
# logger.info('Initial Model Finished')

In [None]:
# Train
# The Logger cell of this notebook is commented out, so `logger`, `tb_logger`
# and `wandb_logger` may never have been defined. Resolve them here so the
# loop does not die with a NameError at the first logging step; if that cell
# is re-enabled and run, the objects it created are picked up unchanged.
logger = globals().get('logger')
if logger is None:
    logger = logging.getLogger('base')
    if not logger.hasHandlers():
        # No handler anywhere up the hierarchy: send INFO progress messages
        # to the console (StreamHandler defaults to stderr).
        logger.addHandler(logging.StreamHandler())
        logger.setLevel(logging.INFO)
tb_logger = globals().get('tb_logger')        # None -> TensorBoard scalars are skipped
wandb_logger = globals().get('wandb_logger')  # None -> W&B metrics are skipped

# Start from wherever the model says it is (non-zero after a resume).
current_step = diffusion.begin_step
current_epoch = diffusion.begin_epoch
n_iter = opt['train']['n_iter']

if opt['path']['resume_state']:
    logger.info('Resuming training from epoch: {}, iter: {}.'.format(
        current_epoch, current_step))

# Install the noise schedule that belongs to the active phase.
diffusion.set_new_noise_schedule(
    opt['model']['beta_schedule'][opt['phase']], schedule_phase=opt['phase'])
if opt['phase'] == 'train':
    # Keep cycling over the loader (one pass == one epoch) until the
    # iteration budget is used up.
    while current_step < n_iter:
        current_epoch += 1
        for train_data in train_loader:
            current_step += 1
            if current_step > n_iter:
                break
            diffusion.feed_data(train_data)
            diffusion.optimize_parameters()
            # log
            if current_step % opt['train']['print_freq'] == 0:
                logs = diffusion.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}> '.format(
                    current_epoch, current_step)
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

                if wandb_logger:
                    wandb_logger.log_metrics(logs)

    #         # validation
    #         if current_step % opt['train']['val_freq'] == 0:
    #             avg_psnr = 0.0
    #             idx = 0
    #             result_path = '{}/{}'.format(opt['path']
    #                                             ['results'], current_epoch)
    #             os.makedirs(result_path, exist_ok=True)

    #             diffusion.set_new_noise_schedule(
    #                 opt['model']['beta_schedule']['val'], schedule_phase='val')
    #             for _,  val_data in enumerate(val_loader):
    #                 idx += 1
    #                 diffusion.feed_data(val_data)
    #                 diffusion.test(continous=False)
    #                 visuals = diffusion.get_current_visuals()
    #                 sr_img = Metrics.tensor2img(visuals['SR'])  # uint8
    #                 hr_img = Metrics.tensor2img(visuals['HR'])  # uint8
    #                 lr_img = Metrics.tensor2img(visuals['LR'])  # uint8
    #                 fake_img = Metrics.tensor2img(visuals['INF'])  # uint8

    #                 # generation
    #                 Metrics.save_img(
    #                     hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx))
    #                 Metrics.save_img(
    #                     sr_img, '{}/{}_{}_sr.png'.format(result_path, current_step, idx))
    #                 Metrics.save_img(
    #                     lr_img, '{}/{}_{}_lr.png'.format(result_path, current_step, idx))
    #                 Metrics.save_img(
    #                     fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))
    #                 tb_logger.add_image(
    #                     'Iter_{}'.format(current_step),
    #                     np.transpose(np.concatenate(
    #                         (fake_img, sr_img, hr_img), axis=1), [2, 0, 1]),
    #                     idx)
    #                 avg_psnr += Metrics.calculate_psnr(
    #                     sr_img, hr_img)

    #                 if wandb_logger:
    #                     wandb_logger.log_image(
    #                         f'validation_{idx}', 
    #                         np.concatenate((fake_img, sr_img, hr_img), axis=1)
    #                     )

    #             avg_psnr = avg_psnr / idx
    #             diffusion.set_new_noise_schedule(
    #                 opt['model']['beta_schedule']['train'], schedule_phase='train')
    #             # log
    #             logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
    #             logger_val = logging.getLogger('val')  # validation logger
    #             logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
    #                 current_epoch, current_step, avg_psnr))
    #             # tensorboard logger
    #             tb_logger.add_scalar('psnr', avg_psnr, current_step)

    #             if wandb_logger:
    #                 wandb_logger.log_metrics({
    #                     'validation/val_psnr': avg_psnr,
    #                     'validation/val_step': val_step
    #                 })
    #                 val_step += 1

    #         if current_step % opt['train']['save_checkpoint_freq'] == 0:
    #             logger.info('Saving models and training states.')
    #             diffusion.save_network(current_epoch, current_step)

    #             if wandb_logger and opt['log_wandb_ckpt']:
    #                 wandb_logger.log_checkpoint(current_epoch, current_step)

    #     if wandb_logger:
    #         wandb_logger.log_metrics({'epoch': current_epoch-1})

    # # save model
    # logger.info('End of training.')


In [None]:
# Validation handling (not considered for now; moving it to a separate notebook may make sense and still needs to be decided)

# else:
#     logger.info('Begin Model Evaluation.')
#     avg_psnr = 0.0
#     avg_ssim = 0.0
#     idx = 0
#     result_path = '{}'.format(opt['path']['results'])
#     os.makedirs(result_path, exist_ok=True)
#     for _,  val_data in enumerate(val_loader):
#         idx += 1
#         diffusion.feed_data(val_data)
#         diffusion.test(continous=True)
#         visuals = diffusion.get_current_visuals()

#         hr_img = Metrics.tensor2img(visuals['HR'])  # uint8
#         lr_img = Metrics.tensor2img(visuals['LR'])  # uint8
#         fake_img = Metrics.tensor2img(visuals['INF'])  # uint8

#         sr_img_mode = 'grid'
#         if sr_img_mode == 'single':
#             # single img series
#             sr_img = visuals['SR']  # uint8
#             sample_num = sr_img.shape[0]
#             for iter in range(0, sample_num):
#                 Metrics.save_img(
#                     Metrics.tensor2img(sr_img[iter]), '{}/{}_{}_sr_{}.png'.format(result_path, current_step, idx, iter))
#         else:
#             # grid img
#             sr_img = Metrics.tensor2img(visuals['SR'])  # uint8
#             Metrics.save_img(
#                 sr_img, '{}/{}_{}_sr_process.png'.format(result_path, current_step, idx))
#             Metrics.save_img(
#                 Metrics.tensor2img(visuals['SR'][-1]), '{}/{}_{}_sr.png'.format(result_path, current_step, idx))

#         Metrics.save_img(
#             hr_img, '{}/{}_{}_hr.png'.format(result_path, current_step, idx))
#         Metrics.save_img(
#             lr_img, '{}/{}_{}_lr.png'.format(result_path, current_step, idx))
#         Metrics.save_img(
#             fake_img, '{}/{}_{}_inf.png'.format(result_path, current_step, idx))

#         # generation
#         eval_psnr = Metrics.calculate_psnr(Metrics.tensor2img(visuals['SR'][-1]), hr_img)
#         eval_ssim = Metrics.calculate_ssim(Metrics.tensor2img(visuals['SR'][-1]), hr_img)

#         avg_psnr += eval_psnr
#         avg_ssim += eval_ssim

#         if wandb_logger and opt['log_eval']:
#             wandb_logger.log_eval_data(fake_img, Metrics.tensor2img(visuals['SR'][-1]), hr_img, eval_psnr, eval_ssim)

#     avg_psnr = avg_psnr / idx
#     avg_ssim = avg_ssim / idx

#     # log
#     logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
#     logger.info('# Validation # SSIM: {:.4e}'.format(avg_ssim))
#     logger_val = logging.getLogger('val')  # validation logger
#     logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}, ssim：{:.4e}'.format(
#         current_epoch, current_step, avg_psnr, avg_ssim))

#     if wandb_logger:
#         if opt['log_eval']:
#             wandb_logger.log_eval_table()
#         wandb_logger.log_metrics({
#             'PSNR': float(avg_psnr),
#             'SSIM': float(avg_ssim)
#         })