In [45]:
import deepinv as dinv
from torch.utils.data import DataLoader
import torch
from pathlib import Path
from torchvision import transforms
from deepinv.optim.prior import PnP
from deepinv.models.denoiser import Denoiser
from deepinv.utils.demo import load_dataset, load_degradation
from deepinv.training_utils import train, test
from deepinv.models.denoiser import online_weights_path

In [46]:
n_channels = 2  # real + imaginary parts

# Set up the trainable denoising prior
denoiser_spec = {
    "name": "dncnn",
    "args": {
        "in_channels": n_channels,
        "out_channels": n_channels,
        "depth": 7,
        "pretrained": None,
        "train": True,
        "device": 'cpu',
    },
}

# If the prior dict value is initialized with a table of length max_iter, then a distinct model is trained for each
# iteration. For fixed trained model prior across iterations, initialize with a single model.
prior = PnP(denoiser=Denoiser(denoiser_spec))

In [47]:
data_fidelity = dinv.optim.L2()

max_iter = 3  # number of unfolded layers
lamb = [1.0] * max_iter  # initialization of the regularization parameter
stepsize = [1.0] * max_iter  # initialization of the step sizes.
sigma_denoiser = [0.01] * max_iter  # initialization of the denoiser parameters
params_algo = {  # wrap all the restoration parameters in a 'params_algo' dictionary
    "stepsize": stepsize,
    "g_param": sigma_denoiser,
    "lambda": lamb,
}

trainable_params = [
    "lambda",
    "stepsize",
    "g_param",
]  # define which parameters from 'params_algo' are trainable

# Define the unfolded trainable model.
model = dinv.unfolded.Unfolded(
    "HQS",
    params_algo=params_algo,
    trainable_params=trainable_params,
    data_fidelity=data_fidelity,
    max_iter=max_iter,
    prior=prior,
)

In [48]:
url = online_weights_path() + "demo_ei_ckp_150.pth"
ckpt = torch.hub.load_state_dict_from_url(
    url, map_location=lambda storage, loc: storage, file_name="demo_ei_ckp_150.pth"
)
ckpt = torch.load_state_dict('new_demo_ei_ckp_150.ckpt')
# load a checkpoint to reduce training time
model.load_state_dict(ckpt["state_dict"])

RuntimeError: Error(s) in loading state_dict for BaseUnfold:
	Missing key(s) in state_dict: "init_params_algo.g_param.0", "init_params_algo.g_param.1", "init_params_algo.g_param.2", "init_params_algo.lambda.0", "init_params_algo.lambda.1", "init_params_algo.lambda.2", "init_params_algo.stepsize.0", "init_params_algo.stepsize.1", "init_params_algo.stepsize.2", "prior.0.denoiser.denoiser.in_conv.weight", "prior.0.denoiser.denoiser.in_conv.bias", "prior.0.denoiser.denoiser.conv_list.0.weight", "prior.0.denoiser.denoiser.conv_list.0.bias", "prior.0.denoiser.denoiser.conv_list.1.weight", "prior.0.denoiser.denoiser.conv_list.1.bias", "prior.0.denoiser.denoiser.conv_list.2.weight", "prior.0.denoiser.denoiser.conv_list.2.bias", "prior.0.denoiser.denoiser.conv_list.3.weight", "prior.0.denoiser.denoiser.conv_list.3.bias", "prior.0.denoiser.denoiser.conv_list.4.weight", "prior.0.denoiser.denoiser.conv_list.4.bias", "prior.0.denoiser.denoiser.out_conv.weight", "prior.0.denoiser.denoiser.out_conv.bias". 
	Unexpected key(s) in state_dict: "params_algo.g_param.0", "params_algo.g_param.1", "params_algo.g_param.2", "params_algo.lambda.0", "params_algo.lambda.1", "params_algo.lambda.2", "params_algo.stepsize.0", "params_algo.stepsize.1", "params_algo.stepsize.2", "prior.prox_g.0.denoiser.in_conv.weight", "prior.prox_g.0.denoiser.in_conv.bias", "prior.prox_g.0.denoiser.conv_list.0.weight", "prior.prox_g.0.denoiser.conv_list.0.bias", "prior.prox_g.0.denoiser.conv_list.1.weight", "prior.prox_g.0.denoiser.conv_list.1.bias", "prior.prox_g.0.denoiser.conv_list.2.weight", "prior.prox_g.0.denoiser.conv_list.2.bias", "prior.prox_g.0.denoiser.conv_list.3.weight", "prior.prox_g.0.denoiser.conv_list.3.bias", "prior.prox_g.0.denoiser.conv_list.4.weight", "prior.prox_g.0.denoiser.conv_list.4.bias", "prior.prox_g.0.denoiser.out_conv.weight", "prior.prox_g.0.denoiser.out_conv.bias". 

In [49]:
new_weights = dict((key.replace('params_algo', 'init_params_algo'), value) for (key, value) in ckpt['state_dict'].items())
new_dict = ckpt
new_dict['state_dict'] = new_weights

new_weights = dict((key.replace('denoiser.', 'denoiser.denoiser.'), value) for (key, value) in new_weights.items())
new_dict = ckpt
new_dict['state_dict'] = new_weights

new_weights = dict((key.replace('prior.prox_g.', 'prior.'), value) for (key, value) in new_weights.items())
new_dict = ckpt
new_dict['state_dict'] = new_weights

model.load_state_dict(new_dict["state_dict"])
torch.save(new_dict, 'new_demo_ei_ckp_150.ckpt')

In [33]:
print(new_dict['state_dict'].keys())

dict_keys(['init_init_params_algo.g_param.0', 'init_init_params_algo.g_param.1', 'init_init_params_algo.g_param.2', 'init_init_params_algo.lambda.0', 'init_init_params_algo.lambda.1', 'init_init_params_algo.lambda.2', 'init_init_params_algo.stepsize.0', 'init_init_params_algo.stepsize.1', 'init_init_params_algo.stepsize.2', 'prior.prox_g.0.denoiser.denoiser.in_conv.weight', 'prior.prox_g.0.denoiser.denoiser.in_conv.bias', 'prior.prox_g.0.denoiser.denoiser.conv_list.0.weight', 'prior.prox_g.0.denoiser.denoiser.conv_list.0.bias', 'prior.prox_g.0.denoiser.denoiser.conv_list.1.weight', 'prior.prox_g.0.denoiser.denoiser.conv_list.1.bias', 'prior.prox_g.0.denoiser.denoiser.conv_list.2.weight', 'prior.prox_g.0.denoiser.denoiser.conv_list.2.bias', 'prior.prox_g.0.denoiser.denoiser.conv_list.3.weight', 'prior.prox_g.0.denoiser.denoiser.conv_list.3.bias', 'prior.prox_g.0.denoiser.denoiser.conv_list.4.weight', 'prior.prox_g.0.denoiser.denoiser.conv_list.4.bias', 'prior.prox_g.0.denoiser.denoiser.