In [71]:
# Run the rest of the notebook from the train/ directory; the config cell
# below derives the project root as the parent of the current directory.
%cd ../train

/home/jupyter/work/resources/SR-Gaming-Bench/train
/home/jupyter/work/resources/SR-Gaming-Bench/train


In [84]:
import datetime
import logging
import os
import random
import time
from pathlib import Path
from typing import Any

import torch
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model
from basicsr.train import (
    create_train_val_dataloader,
    init_tb_loggers,
    load_resume_state,
)
from basicsr.utils import (
    AvgTimer,
    MessageLogger,
    get_env_info,
    get_root_logger,
    get_time_str,
    make_exp_dirs,
    mkdir_and_rename,
    set_random_seed,
)
from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import (
    _postprocess_yml_value,
    copy_opt_file,
    dict2str,
    yaml_load,
)

In [85]:
def _apply_force_yml(opt: dict[str, Any], entries: list[str]) -> None:
    """Apply ``"key1:key2:...=value"`` overrides to ``opt`` in place.

    The nested dict is walked explicitly instead of building and ``exec``-ing a
    Python statement, so override keys are never interpreted as code. The value
    is everything after the first ``=``, so values may themselves contain ``=``.

    Raises:
        ValueError: If an entry has no ``=`` or contains an empty key.
        KeyError: If an intermediate key does not exist in ``opt``.
    """
    for entry in entries:
        key_path, sep, raw_value = entry.partition("=")
        if not sep:
            raise ValueError(
                f"force_yml entry {entry!r} must have the form 'key1:key2=value'."
            )
        keys = [key.strip() for key in key_path.split(":")]
        if not all(keys):
            raise ValueError(f"force_yml entry {entry!r} contains an empty key.")
        value = _postprocess_yml_value(raw_value.strip())

        # Only the last key may be new; all parents must already exist.
        target = opt
        for key in keys[:-1]:
            target = target[key]
        target[keys[-1]] = value


def parse_options(
    root_path: str,
    opt_path: str,
    is_train: bool = True,
    options_in_root: bool = False,
    launcher: str = None,
    auto_resume: bool = False,
    debug: bool = False,
    force_yml: list[str] = None,
) -> dict[str, Any]:
    """Load a BasicSR YAML config and fill in the derived runtime options.

    Notebook-friendly counterpart of BasicSR's ``parse_options``: the
    command-line arguments become keyword parameters.

    Args:
        root_path: Project root (str or path-like). Used to resolve ``opt_path``
            when ``options_in_root`` is True, and as the default parent of the
            ``experiments`` / ``results`` directories.
        opt_path: Path to the YAML config file.
        is_train: If True, fill in training paths (``experiments_root``,
            ``models``, ``training_states``, ...); otherwise result paths.
        options_in_root: Join ``opt_path`` onto ``root_path`` before loading.
        launcher: Distributed launcher passed to ``init_dist``. ``None`` (or
            the string ``"none"``) disables distributed mode.
        auto_resume: Stored as ``opt["auto_resume"]``.
        debug: Prefix the experiment name with ``debug_``; for training this
            also shortens the print/checkpoint/validation intervals.
        force_yml: Overrides of the form ``"key1:key2:...=value"``. They are
            applied right after loading, before anything below reads ``opt``.

    Returns:
        The loaded config dict with derived fields added (``dist``, ``rank``,
        ``world_size``, ``manual_seed``, ``is_train``, dataset phases, paths).

    Raises:
        ValueError: On a malformed ``force_yml`` entry.
        KeyError: If a ``force_yml`` entry references a missing parent key.
    """
    if options_in_root:
        opt_path = os.path.join(root_path, opt_path)
    opt = yaml_load(opt_path)

    # Apply overrides first so that everything derived below (distributed
    # params, random seed, dataset and path post-processing) sees them.
    # Applying them after seeding would store a forced ``manual_seed`` in
    # ``opt`` without actually using it.
    if force_yml is not None:
        _apply_force_yml(opt, force_yml)

    # Also accept "none", BasicSR's CLI spelling for "no launcher".
    if launcher is None or launcher == "none":
        opt["dist"] = False
        print("Disable distributed.", flush=True)
    else:
        opt["dist"] = True
        if launcher == "slurm" and "dist_params" in opt:
            init_dist(launcher, **opt["dist_params"])
        else:
            init_dist(launcher)
    opt["rank"], opt["world_size"] = get_dist_info()

    # Draw a seed if none is configured; each rank is seeded with seed + rank.
    seed = opt.get("manual_seed")
    if seed is None:
        seed = random.randint(1, 10000)
        opt["manual_seed"] = seed
    set_random_seed(seed + opt["rank"])

    opt["auto_resume"] = auto_resume
    opt["is_train"] = is_train

    if debug and not opt["name"].startswith("debug"):
        opt["name"] = "debug_" + opt["name"]

    if opt["num_gpu"] == "auto":
        opt["num_gpu"] = torch.cuda.device_count()

    # opt["datasets"]["type"] selects which group of dataset configs is used.
    datasets = opt["datasets"][opt["datasets"]["type"]]
    for phase, dataset in datasets.items():
        # Keys such as "train" or "val_1" map to phases "train" / "val".
        dataset["phase"] = phase.split("_")[0]
        if "scale" in opt:
            dataset["scale"] = opt["scale"]
        for root_key in ("dataroot_gt", "dataroot_lq"):
            if dataset.get(root_key) is not None:
                dataset[root_key] = os.path.expanduser(dataset[root_key])

    for key, val in opt["path"].items():
        if val is not None and ("resume_state" in key or "pretrain_network" in key):
            opt["path"][key] = os.path.expanduser(val)

    if is_train:
        experiments_root = opt["path"].get("experiments_root")
        if experiments_root is None:
            experiments_root = os.path.join(root_path, "experiments")
        experiments_root = os.path.join(experiments_root, opt["name"])

        opt["path"]["experiments_root"] = experiments_root
        opt["path"]["models"] = os.path.join(experiments_root, "models")
        opt["path"]["training_states"] = os.path.join(
            experiments_root, "training_states"
        )
        opt["path"]["log"] = experiments_root
        opt["path"]["visualization"] = os.path.join(experiments_root, "visualization")

        # Debug runs log, checkpoint and validate much more frequently.
        if "debug" in opt["name"]:
            if "val" in opt:
                opt["val"]["val_freq"] = 8
            opt["logger"]["print_freq"] = 1
            opt["logger"]["save_checkpoint_freq"] = 8
    else:
        results_root = opt["path"].get("results_root")
        if results_root is None:
            results_root = os.path.join(root_path, "results")
        results_root = os.path.join(results_root, opt["name"])

        opt["path"]["results_root"] = results_root
        opt["path"]["log"] = results_root
        opt["path"]["visualization"] = os.path.join(results_root, "visualization")

    return opt

In [86]:
# Project layout: the working directory is <root>/train (see the %cd cell),
# configs live under <root>/configs and experiments under <root>/experiments.
root = Path.cwd().parents[0]
exp_folder = root / "experiments"
exp_folder.mkdir(parents=True, exist_ok=True)

opt_path = str(
    root / "configs" / "train" / "finetune_realesrgan_x4plus_game_engine.yaml"
)
opt = parse_options(
    root_path=root,
    opt_path=opt_path,
    is_train=True,
    options_in_root=True,  # harmless here: opt_path is already absolute
)
opt["root_path"] = root

Disable distributed.


In [87]:
# cuDNN benchmark mode autotunes convolution algorithms per input shape and may
# select different algorithms from run to run, which undermines
# deterministic=True. Enabling both is contradictory, so choose one mode here.
CUDNN_DETERMINISTIC = True  # False: faster autotuned kernels, not reproducible

torch.backends.cudnn.deterministic = CUDNN_DETERMINISTIC
torch.backends.cudnn.benchmark = not CUDNN_DETERMINISTIC

In [88]:
# A loaded resume state means the experiment folders already exist.
resume_state = load_resume_state(opt)

if resume_state is None:
    # Fresh run: create the experiment directory tree.
    make_exp_dirs(opt)
    # TensorBoard folder: rank 0 only, non-debug runs with use_tb_logger set.
    # NOTE(review): this path is built from exp_folder; confirm it is the same
    # directory that init_tb_loggers(opt) writes to.
    if (opt["logger"].get("use_tb_logger")
            and "debug" not in opt["name"]
            and opt["rank"] == 0):
        mkdir_and_rename(os.path.join(exp_folder, "tb_logger", opt["name"]))

# Copy the YAML config into the experiment root.
copy_opt_file(root / opt_path, opt["path"]["experiments_root"])

In [89]:
# Log file named after the experiment and get_time_str(), in opt["path"]["log"].
log_file = os.path.join(
    opt["path"]["log"],
    f"train_{opt['name']}_{get_time_str()}.log",
)
logger = get_root_logger(
    logger_name="basicsr",
    log_level=logging.INFO,
    log_file=log_file,
)

# Record the environment and the fully resolved options for this run.
logger.info(get_env_info())
logger.info(dict2str(opt))

tb_logger = init_tb_loggers(opt)

2024-02-22 17:09:42,815 INFO: 
                ____                _       _____  ____
               / __ ) ____ _ _____ (_)_____/ ___/ / __ \
              / __  |/ __ `// ___// // ___/\__ \ / /_/ /
             / /_/ // /_/ /(__  )/ // /__ ___/ // _, _/
            /_____/ \__,_//____//_/ \___//____//_/ |_|
     ______                   __   __                 __      __
    / ____/____   ____   ____/ /  / /   __  __ _____ / /__   / /
   / / __ / __ \ / __ \ / __  /  / /   / / / // ___// //_/  / /
  / /_/ // /_/ // /_/ // /_/ /  / /___/ /_/ // /__ / /<    /_/
  \____/ \____/ \____/ \____/  /_____/\____/ \___//_/|_|  (_)
    
Version Information: 
	BasicSR: 1.4.2
	PyTorch: 2.1.2+cu121
	TorchVision: 0.16.2+cu121
2024-02-22 17:09:42,817 INFO: 
  name: finetune_RealESRGANx4plus_GameEngineData
  model_type: RealESRGANModel
  scale: 4
  num_gpu: 1
  manual_seed: 0
  l1_gt_usm: True
  percep_gt_usm: True
  gan_gt_usm: False
  high_order_degradation: False
  datasets:[
    type: files
    h

In [90]:
# Build the train/val dataloaders. Only the "files" dataset layout is wired up;
# fail loudly for anything else instead of leaving train_loader & co. undefined
# and surfacing as a NameError in a later cell.
datasets_type = opt["datasets"]["type"]
if datasets_type != "files":
    raise NotImplementedError(
        f"Unsupported datasets type {datasets_type!r}; only 'files' is implemented."
    )

(
    train_loader,
    train_sampler,
    val_loaders,
    total_epochs,
    total_iters,
) = create_train_val_dataloader(opt, logger, datasets_type, root=root)

Name RealESRGANPairedDataset is not found, use name: RealESRGANPairedDataset_basicsr!


2024-02-22 17:09:45,467 INFO: Dataset [RealESRGANPairedDataset] - GameEngine_train is built.
2024-02-22 17:09:45,469 INFO: Training statistics:
	Number of train images: 14431
	Dataset enlarge ratio: 1
	Batch size per gpu: 12
	World size (gpu number): 1
	Require iter number per epoch: 1203
	Total epochs: 84; iters: 100000.
2024-02-22 17:09:45,712 INFO: Dataset [PairedImageDataset] - GameEngine_val is built.
2024-02-22 17:09:45,713 INFO: Number of val images/folders in GameEngine_val: 3600


In [91]:
model = build_model(opt, root=root)

# Resume from a saved training state if one was loaded; otherwise start at 0.
# (Same `is None` test as the experiment-directory cell, for consistency.)
if resume_state is not None:
    model.resume_training(resume_state)
    start_epoch = resume_state["epoch"]
    current_iter = resume_state["iter"]
    logger.info(
        f"Resuming training from epoch: {start_epoch}, iter: {current_iter}."
    )
else:
    start_epoch = 0
    current_iter = 0

Name RealESRGANModel is not found, use name: RealESRGANModel_basicsr!


2024-02-22 17:09:46,772 INFO: Network [RRDBNet] is created.
2024-02-22 17:09:46,832 INFO: Network: RRDBNet, with parameters: 16,697,987
2024-02-22 17:09:46,833 INFO: RRDBNet(
  (conv_first): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (body): Sequential(
    (0): RRDB(
      (rdb1): ResidualDenseBlock(
        (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv3): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv4): Conv2d(160, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv5): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (rdb2): ResidualDenseBlock(
        (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), pa

Name UNetDiscriminatorSN is not found, use name: UNetDiscriminatorSN_basicsr!


2024-02-22 17:09:47,781 INFO: Network [UNetDiscriminatorSN] is created.
2024-02-22 17:09:47,787 INFO: Network: UNetDiscriminatorSN, with parameters: 4,376,897
2024-02-22 17:09:47,788 INFO: UNetDiscriminatorSN(
  (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv3): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv4): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv5): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias

In [92]:
# Periodic training-stat logger. The training loop calls it with a dict of log
# variables. It is given the starting iteration (0, or the resumed iteration)
# and the TensorBoard logger from init_tb_loggers.
msg_logger = MessageLogger(opt, current_iter, tb_logger)

In [93]:
# Choose how training batches are prefetched (CPU by default).
train_dataset_opt = opt["datasets"][datasets_type]["train"]
prefetch_mode = train_dataset_opt.get("prefetch_mode")

if prefetch_mode is None or prefetch_mode == "cpu":
    prefetcher = CPUPrefetcher(train_loader)
elif prefetch_mode == "cuda":
    # Validate before constructing, so a failed check does not leave a freshly
    # built (misconfigured) CUDAPrefetcher bound to `prefetcher` in the kernel.
    if train_dataset_opt.get("pin_memory") is not True:
        raise ValueError("Please set pin_memory=True for CUDAPrefetcher.")
    prefetcher = CUDAPrefetcher(train_loader, opt)
    logger.info(f"Use {prefetch_mode} prefetch dataloader")
else:
    raise ValueError(
        f"Wrong prefetch_mode {prefetch_mode}. "
        f"Supported ones are: None, 'cuda', 'cpu'."
    )

In [None]:
logger.info(f"Start training from epoch: {start_epoch}, iter: {current_iter}")
data_timer, iter_timer = AvgTimer(), AvgTimer()
start_time = time.time()

for epoch in range(start_epoch, total_epochs + 1):
    # The iteration budget, not the epoch count, ends training. Stop the outer
    # loop as well once it is spent. Otherwise every remaining epoch would still
    # reset the prefetcher, fetch a batch and push current_iter past
    # total_iters, and the final validation would be logged at that bogus step.
    if current_iter >= total_iters:
        break

    train_sampler.set_epoch(epoch)
    prefetcher.reset()
    train_data = prefetcher.next()

    while train_data is not None and current_iter < total_iters:
        data_timer.record()
        current_iter += 1

        model.update_learning_rate(
            current_iter, warmup_iter=opt["train"].get("warmup_iter", -1)
        )

        model.feed_data(train_data)
        model.optimize_parameters(current_iter)
        iter_timer.record()
        if current_iter == 1:
            # Restart the message logger's clock at the first iteration.
            msg_logger.reset_start_time()

        if current_iter % opt["logger"]["print_freq"] == 0:
            log_vars = {"epoch": epoch, "iter": current_iter}
            log_vars.update({"lrs": model.get_current_learning_rate()})
            log_vars.update(
                {
                    "time": iter_timer.get_avg_time(),
                    "data_time": data_timer.get_avg_time(),
                }
            )
            log_vars.update(model.get_current_log())
            msg_logger(log_vars)

        if current_iter % opt["logger"]["save_checkpoint_freq"] == 0:
            logger.info("Saving models and training states.")
            model.save(epoch, current_iter)

        if opt.get("val") is not None and (current_iter % opt["val"]["val_freq"] == 0):
            if len(val_loaders) > 1:
                logger.warning(
                    "Multiple validation datasets are *only* supported by SRModel."
                )
            for val_loader in val_loaders:
                model.validation(
                    val_loader,
                    current_iter,
                    tb_logger,
                    opt["val"]["save_img"],
                    use_first_n_batches=opt["val"]["use_first_n_batches"],
                )

        data_timer.start()
        iter_timer.start()
        train_data = prefetcher.next()

consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
logger.info(f"End of training. Time consumed: {consumed_time}")
logger.info("Save the latest model.")

model.save(epoch=-1, current_iter=-1)  # -1 stands for the latest
if opt.get("val") is not None:
    # NOTE(review): unlike the periodic validation, use_first_n_batches is not
    # passed here; confirm a full-set final evaluation is intended.
    for val_loader in val_loaders:
        model.validation(val_loader, current_iter, tb_logger, opt["val"]["save_img"])

if tb_logger:
    tb_logger.close()