In [25]:
import argparse
import datetime
import logging
import math
import random
import time
import torch
from os import path as osp
import os

from basicsr.data import create_dataloader, create_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import create_model
from basicsr.utils import (MessageLogger, check_resume, get_env_info,
                           get_root_logger, get_time_str, init_tb_logger,
                           init_wandb_logger, make_exp_dirs, mkdir_and_rename,
                           set_random_seed)
from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import dict2str, parse
from copy import deepcopy

In [17]:
def parse_options(is_train=True):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-opt', type=str, default='../options/train/NAFSSR/NAFSSR-T_x4.yml', help='Path to option YAML file.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)

    parser.add_argument('--input_path', type=str, required=False, help='The path to the input image. For single image inference only.')
    parser.add_argument('--output_path', type=str, required=False, help='The path to the output image. For single image inference only.')
    parser.add_argument('--f', required=False)

    args = parser.parse_args()
    opt = parse(args.opt, is_train=is_train)

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
            print('init dist .. ', args.launcher)

    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    if args.input_path is not None and args.output_path is not None:
        opt['img_path'] = {
            'input_img': args.input_path,
            'output_img': args.output_path
        }
    opt['path']['models'] = '../test_results'
    return opt


Disable distributed.


In [27]:
def save_network(self, net, net_label, current_iter, param_key='params'):
    """Save networks.

    Args:
        net (nn.Module | list[nn.Module]): Network(s) to be saved.
        net_label (str): Network label.
        current_iter (int): Current iter number.
        param_key (str | list[str]): The parameter key(s) to save network.
            Default: 'params'.
    """
    if current_iter == -1:
        current_iter = 'latest'
    save_filename = f'{net_label}_{current_iter}.pth'
    save_path = os.path.join(self.opt['path']['models'], save_filename)

    net = net if isinstance(net, list) else [net]
    param_key = param_key if isinstance(param_key, list) else [param_key]
    assert len(net) == len(
        param_key), 'The lengths of net and param_key should be the same.'

    save_dict = {}
    for net_, param_key_ in zip(net, param_key):
        net_ = self.get_bare_model(net_)
        state_dict = net_.state_dict()
        for key, param in state_dict.items():
            if key.startswith('module.'):  # remove unnecessary 'module.'
                key = key[7:]
            state_dict[key] = param.cpu().to(torch.float16)
            # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ converted float32 to float16
        save_dict[param_key_] = state_dict

    torch.save(save_dict, save_path)

def load_network(self, net, load_path, strict=True, param_key='params'):
    """Load network.

    Args:
        load_path (str): The path of networks to be loaded.
        net (nn.Module): Network.
        strict (bool): Whether strictly loaded.
        param_key (str): The parameter key of loaded network. If set to
            None, use the root 'path'.
            Default: 'params'.
    """
    net = self.get_bare_model(net)
    load_net = torch.load(
        load_path, map_location=lambda storage, loc: storage)
    if param_key is not None:
        load_net = load_net[param_key]
    print(' load net keys', load_net.keys)
    # remove unnecessary 'module.'
    for k, v in deepcopy(load_net).items():
        # v = v.to(torch.int8)
        if k.startswith('module.'):
            load_net[k[7:]] = v
            load_net.pop(k)
        # print(v.dtype)
    self._print_different_keys_loading(net, load_net, strict)
    net.load_state_dict(load_net, strict=strict)

In [19]:
opt = parse_options(is_train=True)
model = create_model(opt)

2024-02-04 11:05:34,648 INFO: Model [ImageRestorationModel] is created.


.. cosineannealingLR


In [28]:
save_network(model, model.net_g, 'net_g', -1)

In [26]:
load_network(model, model.net_g, '../test_results/net_g_latest.pth')

 load net keys <built-in method keys of collections.OrderedDict object at 0x7fa83d51c440>


In [None]:
import pickle

# Load the model
model = torch.load('model.pt')

# Open a file in binary write mode
with open('model.pkl', 'wb') as f:
    # Serialize the model to the file
    pickle.dump(model, f)

# Close the file
f.close()