In [None]:
import zlib
from os import listdir, remove

import jax.numpy as jnp
import numpy as np
import optuna
from scico import functional, linop, loss, metric
from scico.optimize import PDHG

from src.forward_operator.forward_operator import forward_operator
from src.forward_operator.operators import *

from src.inversions.baseline_method.inversion_baseline import *

from src.input_initialization import initialize_input

In [None]:
# Input images: file names are appended directly to this prefix by string
# concatenation, so it must keep its trailing slash.
INPUT_DIR = 'input/ms/'

# Colour filter array (CFA) pattern to simulate; pixel binning is switched on
# only for the 'quad_bayer' pattern.
CFA = 'sony'
BINNING = CFA in ('quad_bayer',)

# Standard deviation of the additive Gaussian measurement noise, expressed as a
# percentage of the [0, 1] intensity range (it is divided by 100 before sampling).
NOISE_LEVEL = 5

In [None]:
res = []  # best (lowest) MSE found for each image, in processing order

# Sort so the order of `res` (and of the printed running means) does not
# depend on the filesystem's directory-listing order.
for i, image_name in enumerate(sorted(listdir(INPUT_DIR))):
    x, spectral_stencil = initialize_input(INPUT_DIR + image_name)

    # Forward model ('dirac' CFA sampling) and the project's baseline inverse,
    # whose output is used below as the PDHG warm start.
    cfa_op = cfa_operator(CFA, x.shape, spectral_stencil, 'dirac')
    forward_op = forward_operator([cfa_op])
    baseline_inverse = Inverse_problem(CFA, BINNING, 0, x.shape, spectral_stencil, 'dirac')

    def forward_pass(x):
        """Apply the CFA forward operator and return the result as a JAX array."""
        return jnp.array(forward_op.direct(x))

    def adjoint_pass(y):
        """Apply the adjoint of the CFA forward operator and return a JAX array."""
        return jnp.array(forward_op.adjoint(y))

    # Wrap the project operator as a scico LinearOperator; its output drops the
    # last axis of x (presumably the spectral axis — confirm).
    A = linop.LinearOperator(input_shape=x.shape, output_shape=x.shape[:-1], eval_fn=forward_pass, adj_fn=adjoint_pass)

    # Finite differences along axes 0 and 1 for the TV regulariser. ||C||^2 is
    # used to derive tau so that tau * sigma * ||C||^2 == tmp <= 1.
    C = linop.FiniteDifference(input_shape=x.shape, append=0, axes=(0, 1))
    C_squared_norm = np.float64(linop.operator_norm(C))**2

    # Noise comes from a generator seeded by the file name, so every run sees the
    # same noisy measurement for a given image. This is required because the
    # Optuna study below is persisted and resumed (load_if_exists=True): with
    # unseeded noise, trials from different runs would be scored against
    # different measurements and `best_trial` would compare incomparable values.
    # NOTE: studies saved by an earlier, unseeded version of this cell contain
    # such mixed trials; delete their .sqlite3 files before re-running.
    seed = zlib.crc32(image_name.encode('utf-8'))
    rng = np.random.default_rng(seed)
    y = np.clip(forward_op.direct(x) + rng.normal(0, NOISE_LEVEL / 100, forward_op.output_shape), 0, 1)

    f = loss.SquaredL2Loss(y=jnp.array(y), A=A)

    # NOTE(review): the last axis of the baseline output is reversed before use;
    # presumably its channel order is opposite to that of x — confirm.
    x_baseline = jnp.array(baseline_inverse(y)[:, :, ::-1])

    def objective(trial):
        """Run TV-regularised PDHG with the trial's hyper-parameters.

        Samples the regularisation weight ``lambd``, the step size ``sigma`` and
        the ratio ``tmp``; tau is derived so that tau * sigma * ||C||^2 == tmp.

        Returns:
            MSE between the ground-truth image ``x`` and the PDHG reconstruction.
        """
        lambd = trial.suggest_float('lambd', 1e-3, 0.2, log=True)
        sigma = trial.suggest_float('sigma', 1e-2, 100, log=True)
        step_ratio = trial.suggest_float('tmp', 1e-3, 1)
        tau = step_ratio / (sigma * C_squared_norm)

        # Mixed L2/L1 norm with the L2 taken over axes 0 and 3 of C(x)
        # (presumably difference direction and channel — confirm against
        # FiniteDifference's output layout).
        g = lambd * functional.L21Norm(l2_axis=(0, 3))

        solver_TV = PDHG(
            f=f,
            g=g,
            C=C,
            tau=tau,
            sigma=sigma,
            x0=x_baseline,
            maxiter=400
        )

        return metric.mse(x, solver_TV.solve())

    # One persisted study per image (SQLite file named after the image), resumed
    # across runs. The sampler is seeded like the noise so fresh runs reproduce.
    study = optuna.create_study(
        direction='minimize',
        storage=f'sqlite:///{image_name}.sqlite3',
        study_name='tv',
        load_if_exists=True,
        sampler=optuna.samplers.TPESampler(seed=seed),
    )
    study.optimize(objective, n_trials=10)
    best_trial = study.best_trial

    res.append(best_trial.value)

    # Running mean of the best MSE over the images processed so far.
    print(i, np.mean(res))
    print('----------------------------------------------------------------')

In [None]:
# Summary over all processed images: the per-image best MSE values, then their
# mean and standard deviation (NumPy default ddof=0).
mse_values = np.asarray(res)
print(res)
print(mse_values.mean(), mse_values.std())