In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import svt
from bayes_opt import BayesianOptimization

## Example: tuning singular value thresholding (SVT) with Bayesian optimization

In [8]:
# Seed the generator so re-running this cell reproduces the same X. Without a
# seed, re-executing this cell silently replaces the matrix that the later
# cells were tuned and solved on.
SEED = 0
rng = np.random.default_rng(SEED)
X = rng.random((5, 5))  # ground-truth matrix to be completed

# Observation mask: 1 = observed entry, 0 = entry hidden from the solver.
mask = np.ones_like(X)
nan = [[0, 2], [1, 2], [3, 0], [4, 4]]  # (row, col) positions to hide
for row, col in nan:
    mask[row, col] = 0
mask

array([[1., 1., 0., 1., 1.],
       [1., 1., 0., 1., 1.],
       [1., 1., 1., 1., 1.],
       [0., 1., 1., 1., 1.],
       [1., 1., 1., 1., 0.]])

The `bayes_opt` package maximizes its objective by default, so the objective below returns the *negated* reconstruction error (the Frobenius norm of `X - X_hat`).

In [3]:
def svt_solve(threshold, eps):
    """Objective for bayes_opt: negated reconstruction error of SVT completion.

    Runs ``svt.svt_solve`` on the global ``X``/``mask`` with the given
    hyperparameters and scores the result against ``X``.

    Args:
        threshold: SVT threshold passed through to ``svt.svt_solve``.
        eps: tolerance passed through to ``svt.svt_solve``.

    Returns:
        ``-||X - X_hat||`` (Frobenius norm), so larger is better.

    Note:
        The error is taken over every entry of ``X``, including the entries
        zeroed in ``mask``, so the score uses ground-truth values for the
        hidden cells. That is fine for a synthetic demo but is not available
        in a real completion problem.
    """
    completed = svt.svt_solve(X, mask, threshold=threshold, eps=eps)
    residual = X - completed
    error = np.linalg.norm(residual)
    return -error

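There are probably better hyperparameter ranges; as a starting point, the threshold is searched between the smallest and largest singular values of `X`.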

In [4]:
# Only the singular values are needed, so skip computing U and V.
S = np.linalg.svd(X, compute_uv=False)

# Search box for bayes_opt: one (low, high) pair per objective argument.
pbounds = {
    "threshold": (S.min(), S.max()),  # span of X's singular values
    "eps": (1e-6, 1e-1),
}

In [5]:
# Maximize the negated SVT error over the box in `pbounds`; the optimizer is
# seeded (random_state=1) and runs silently (verbose=0).
optimizer = BayesianOptimization(
    f=svt_solve,
    pbounds=pbounds,
    random_state=1,
    verbose=0,
)

# 2 random probes first, then 100 guided iterations.
optimizer.maximize(init_points=2, n_iter=100)

# Best score found and the parameters that produced it.
optimizer.max

{'target': -0.2585887847778216,
 'params': {'eps': 0.09749825464968818, 'threshold': 1.1832446171648794}}

In [6]:
# Re-run SVT with the best hyperparameters found above.
best_params = optimizer.max["params"]
threshold, eps = best_params["threshold"], best_params["eps"]
# NOTE(review): tuning called svt.svt_solve without max_iters, while this
# final solve uses max_iters=10000, so the tuned params may behave
# differently here. Confirm the solver's default cap.
X_hat = svt.svt_solve(X, mask, threshold=threshold, eps=eps, max_iters=10000)
X_hat

array([[0.77473361, 0.62987193, 0.17941617, 0.11367791, 0.17633671],
       [0.68555944, 0.54680378, 0.21685369, 0.64186838, 0.15563309],
       [0.37784422, 0.59678177, 0.08103329, 0.60760487, 0.65288022],
       [0.66329891, 0.54097945, 0.30680838, 0.82303869, 0.22784574],
       [0.21223024, 0.06426663, 0.34254093, 0.22115931, 0.10481597]])

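Compare with the ground-truth `X` below. Caution: in the saved outputs the two matrices do not correspond. The data cell was re-executed (`In [8]`) after `X_hat` was computed (`In [6]`), which drew a new, unseeded random `X`. Restart the kernel and run all cells top to bottom before judging how close the reconstruction is.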

In [9]:
X  # ground-truth matrix, displayed for comparison with X_hat above

array([[0.27585592, 0.8423638 , 0.57442402, 0.04281608, 0.55469531],
       [0.60353304, 0.33345617, 0.33872183, 0.61957986, 0.18578328],
       [0.37111083, 0.32575947, 0.11195775, 0.65048824, 0.95797305],
       [0.33693962, 0.27662416, 0.74903084, 0.57655028, 0.14362197],
       [0.60742355, 0.55201214, 0.46393487, 0.5142839 , 0.90875592]])