In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from scipy.optimize import minimize_scalar, minimize
from time import time
import seaborn as sns
sns.set_style('darkgrid')
sns.set_context('notebook')
import sys
sys.path.append('..')

In [3]:
from osd import Problem
from osd.components import GaussNoise, SmoothSecondDifference, SparseFirstDiffConvex
from osd.utilities import progress
import cvxpy as cvx

In [29]:
import scipy.linalg as spl

def prox2(v, theta, rho, A=None, return_A=True):
    if A is None:
        n = len(v)
        M = np.diff(np.eye(n), axis=0, n=2)
        r = 2 * theta / rho
        A = np.linalg.inv(np.eye(n) + r * M.T.dot(M))
    if not return_A:
        return A.dot(v)
    else:
        return A.dot(v), A
    
def prox2_alt(v, theta, rho, c=None, return_c=True):
    if c is None:
        n = len(v)
        M = np.diff(np.eye(n), axis=0, n=2)
        r = 2 * theta / rho
        ab = np.zeros((3, n))
        A = np.eye(n) + r * M.T.dot(M)
        for i in range(3):
            ab[i] = np.pad(np.diag(A, k=i), (0, i))
        c = spl.cholesky_banded(ab, lower=True)
    if not return_c:
        return spl.cho_solve_banded((c, True), v)
    else:
        return spl.cho_solve_banded((c, True), v), c


In [30]:
v = np.random.randn(100)
np.allclose(prox2(v, theta=1e4, rho=1, return_A=False), prox2_alt(v, theta=1e4, rho=1, return_c=False))

True

In [33]:
v = np.random.randn(5000)

In [34]:
%timeit prox2(v, theta=1e4, rho=1, return_A=False)

5.16 s ± 238 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [35]:
%timeit prox2_alt(v, theta=1e4, rho=1, return_c=False)

1.88 s ± 112 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [36]:
x, A = prox2(v, theta=1e4, rho=1, return_A=True)
x, c = prox2_alt(v, theta=1e4, rho=1, return_c=True)

In [37]:
%timeit prox2(v, theta=1e4, rho=1, return_A=False, A=A)

8.87 ms ± 139 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [38]:
%timeit prox2_alt(v, theta=1e4, rho=1, return_c=False, c=c)

116 µs ± 1.77 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
