### Low rank approximations simulation

In [33]:
import numpy as np
from scipy.optimize import minimize, nnls, LinearConstraint
import scipy.stats as stats
import matplotlib.pyplot as plt
import seaborn as sns

In [34]:
# Simulation configuration -- everything a reader might want to tune.
num_cells = 500    # number of rows of the simulated matrix M
num_genes = 100    # number of columns of the simulated matrix M
pack_ratio = 0.25  # number of markers, expressed as a fraction of num_genes
rank = 10          # inner dimension of the U/V factorisation
# Number of columns of the 0/1 matrix P that compresses the genes.
num_markers = int(pack_ratio*num_genes)

# Seed NumPy's global generator so every np.random draw further down is
# repeatable when the notebook is re-run from a fresh kernel.
random_seed = 0
np.random.seed(random_seed)

In [35]:

# Ground-truth factors: Gaussian entries shifted away from zero.
U_real = 10 + 2*np.random.randn(num_cells, rank)
V_real = 15 + 3*np.random.randn(num_genes, rank)

# The matrix to recover is the product of the two factors.
M = np.matmul(U_real, V_real.T)
print('M', M.shape)

M (500, 100)


In [36]:
# Random 0/1 matrix with one column per marker; each entry is 1 with probability 0.25.
P = np.random.choice([0, 1], p=[0.75, 0.25], size=(num_genes, num_markers))

In [37]:
# Observed data: M multiplied on the right by the 0/1 matrix P.
D = np.matmul(M, P)

### Nuclear norm minimization

In [20]:
import cvxpy as cp

In [21]:
# Residual tolerance: 1e-7 scaled by the number of rows, then by the number of columns, of D.
delta = (1e-7 * D.shape[0]) * D.shape[1]

In [22]:
# Decision variable: the full (num_cells x num_genes) matrix being recovered.
X = cp.Variable(shape=(num_cells, num_genes))

# Objective value to minimise: the nuclear norm of X.
cost = cp.norm(X, p='nuc')

# Feasible set: X @ P must match D to within delta in Frobenius norm ...
constraints = [cp.norm(X @ P - D, p='fro') <= delta]
# ... and the smallest entry of X must be non-negative.
constraints.append(cp.min(X) >= 0)

In [23]:
# Bundle the objective and the constraints into a single cvxpy problem.
prob = cp.Problem(objective=cp.Minimize(cost), constraints=constraints)

In [24]:
# NOTE(review): the recorded run of this cell did not complete -- the output
# below shows an interrupt followed by a SolverError from SCS -- so X.value
# is presumably unset for the cells that follow; confirm before relying on it.
prob.solve()

Failure:Interrupted


SolverError: Solver 'SCS' failed. Try another solver, or solve with verbose=True for more information.

In [None]:
# Pearson correlation between each column of the true M and the matching
# column of the recovered X (only the coefficient is kept, not the p-value).
corrs = [
    stats.pearsonr(M[:, col], X.value[:, col])[0]
    for col in range(M.shape[1])
]

In [None]:
# Histogram of the per-column correlations gathered in `corrs`.
plt.hist(corrs, bins=10)

### Alternating least squares

In [38]:
# Random starting factors for the alternating scheme: unit-variance Gaussian
# entries centred on 10.
U_current = 10 + np.random.randn(num_cells, rank)
V_current = 10 + np.random.randn(num_genes, rank)

In [39]:
import cvxpy as cp

In [40]:
# V_cvx = cp.Variable(V_current.shape)
# U_cvx = cp.Parameter(U_current.shape, nonneg=True)
# D_cvx = cp.Parameter((num_cells, 1), nonneg=True)
# P_cvx = cp.Parameter((num_genes, 1), nonneg=True)
# cost = cp.norm(U_cvx@V_cvx.T@P-D_cvx, 'fro')
# constraints = [cp.min(V_cvx) >= 0]

# prob = cp.Problem(cp.Minimize(cost), constraints)

# def solve_u(U_current, V_current):
    
#     A = P.T@V_current
#     return np.apply_along_axis(lambda row: nnls(A, row)[0], axis=1, arr=D)

# def solve_v(U_current, V_current):
    
#     to_return = 0
    
#     for marker_idx in range(num_markers):
    
#         U_cvx.value = U_current
#         D_cvx.value = D[:, [marker_idx]]
#         P_cvx.value = P[:,[marker_idx]]
#         V_cvx.value = V_current
#         final_cost = prob.solve()
        
#         to_return += V_cvx.value
    
#     return to_return/num_cells

In [41]:
def solve_u(U_current, V_current):
    """Refit U with V held fixed, one row of D at a time.

    For every row d of the module-level D, solves the non-negative
    least-squares problem min_{u >= 0} ||(P.T @ V_current) u - d||_2
    with scipy's ``nnls``.

    Parameters
    ----------
    U_current : not read by this function; accepted so the signature
        matches ``solve_v``.
    V_current : 2-D array whose first dimension matches the rows of P.

    Returns
    -------
    2-D array with one row per row of D and one column per column of
    ``V_current``.
    """
    design = P.T @ V_current
    return np.stack([nnls(design, observed)[0] for observed in D])

def solve_v(U_current, V_current):
    """Refit V with U held fixed by solving a constrained convex problem.

    Minimises ||U_current @ V.T @ P - D||_F over V, subject to the smallest
    entry of V being non-negative. P and D are read from module scope.

    Parameters
    ----------
    U_current : 2-D array used as the fixed left factor.
    V_current : 2-D array; supplies the shape of V and is assigned as the
        variable's value before solving.

    Returns
    -------
    2-D array of the same shape as ``V_current`` holding the solution.

    Raises
    ------
    RuntimeError
        If the solve finishes without assigning a value to V, instead of
        silently handing ``None`` back to the caller.
    """
    V_cvx = cp.Variable(V_current.shape)
    V_cvx.value = V_current
    cost = cp.norm(U_current@V_cvx.T@P-D, 'fro')
    constraints = [cp.min(V_cvx) >= 0]

    prob = cp.Problem(cp.Minimize(cost), constraints)
    prob.solve()

    # Without this guard a None result would only surface later, as an
    # unrelated-looking error inside the next solve_u call.
    if V_cvx.value is None:
        raise RuntimeError(
            'solve_v: solver returned no solution (status: %s)' % prob.status
        )

    return V_cvx.value

In [42]:
# nonneg_constraint = LinearConstraint(np.eye(num_genes * rank), 0, np.inf)

# def v_objective(V_vec, U_current):
    
#     V_temp = V_vec.reshape(V_current.shape)
#     return ((U_current@V_temp.T@P-D)**2).sum()
    
# def solve_v_scipy(U_current, V_current):
    
#     V_vec = V_current.reshape(-1)
    
#     res = minimize(lambda v: v_objective(v, U_current), V_vec, constraints=[nonneg_constraint])
#     return res

In [43]:
# nonneg_constraint = LinearConstraint(np.eye(num_genes * rank), 0, np.inf)
# res = minimize(lambda v: v_objective(v, U_current), V_current.reshape(-1))

In [44]:
# One U update with V held fixed. solve_u never reads its first argument, so
# re-running this cell with the same V_current performs the same solve.
U_current = solve_u(U_current, V_current)

In [47]:
%%time
# u_error = []
# v_error = []
# Alternate the two sub-problem solves: U given V, then V given the refreshed U.
for i in range(100):

    U_current = solve_u(U_current, V_current)
    V_current = solve_v(U_current, V_current)
    
    # NOTE(review): this unconditional break ends the loop after the first
    # pass, so only one U/V update runs despite range(100); the timing
    # recorded below is for that single pass.
    break

CPU times: user 33.1 s, sys: 185 ms, total: 33.2 s
Wall time: 33.2 s


In [48]:
# Keep the current factors under dedicated names for the evaluation below.
U_fitted, V_fitted = U_current, V_current

In [49]:
# Reconstruct the full matrix from the fitted factors.
M_fitted = np.matmul(U_fitted, V_fitted.T)

In [50]:
# Display the fitted reconstruction, for comparison with the true M shown next.
M_fitted

array([[1888.36965032, 1579.56127265, 1357.47041186, ..., 1148.36570779,
        1741.01656454, 1868.91837721],
       [1709.46471285, 1435.6080692 , 1230.17743223, ..., 1043.35674484,
        1561.84636886, 1691.68263052],
       [1726.78618204, 1475.61735837, 1256.82689742, ..., 1051.023396  ,
        1594.30525098, 1716.03426661],
       ...,
       [1750.10905185, 1510.04688125, 1276.96342332, ..., 1062.0042156 ,
        1606.87403172, 1739.00947878],
       [1811.13527638, 1500.00433228, 1300.11472133, ..., 1087.466107  ,
        1657.13745485, 1787.72227462],
       [1619.34838416, 1408.49237188, 1193.38641363, ..., 1003.57209208,
        1505.7692461 , 1605.5761831 ]])

In [51]:
# Display the ground-truth matrix for a side-by-side look at M_fitted.
M

array([[1667.1316563 , 1518.25081344, 1486.35190252, ..., 1522.94213173,
        1754.26127179, 1652.67773995],
       [1492.46267097, 1374.18553781, 1377.02670022, ..., 1373.07214768,
        1563.58181863, 1510.3577748 ],
       [1512.24411799, 1382.20640026, 1406.1183873 , ..., 1388.04781659,
        1602.49997459, 1518.2753184 ],
       ...,
       [1547.82280275, 1421.72946454, 1443.31080358, ..., 1416.91085355,
        1613.84053425, 1531.89357292],
       [1591.53135753, 1425.07528657, 1438.65625784, ..., 1413.90735879,
        1661.73632997, 1597.03830962],
       [1461.38409524, 1335.61066796, 1351.5683792 , ..., 1340.27296053,
        1515.99813671, 1438.17084434]])

In [52]:
# Print the Pearson result for every column, pairing each true column of M
# with the fitted column of M_fitted at the same position.
for true_col, fitted_col in zip(M.T, M_fitted.T):
    print(stats.pearsonr(true_col, fitted_col))

(0.9824819397388962, 0.0)
(0.9883952630144568, 0.0)
(0.962543222434418, 1.952896212664096e-284)
(0.9930994458776241, 0.0)
(0.9869865407936913, 0.0)
(0.9809217082360407, 0.0)
(0.981294979060586, 0.0)
(0.9844265146572148, 0.0)
(0.9906339792861916, 0.0)
(0.9762721845450449, 0.0)
(0.985207641539266, 0.0)
(0.9654067495199166, 7.028190793110504e-293)
(0.9609525830508534, 5.0248668852659265e-280)
(0.9949900921331604, 0.0)
(0.99112435849324, 0.0)
(0.9829906856772949, 0.0)
(0.985349004644607, 0.0)
(0.9812459957025002, 0.0)
(0.9771633214610873, 0.0)
(0.987011879983598, 0.0)
(0.9823946108511313, 0.0)
(0.9818133520414057, 0.0)
(0.9857948802198786, 0.0)
(0.9755998035141474, 0.0)
(0.9669459318847438, 1.0208511035752284e-297)
(0.9845120059038196, 0.0)
(0.9745822016888608, 0.0)
(0.9684171260980252, 1.464245471753772e-302)
(0.9737515147056379, 2.77e-322)
(0.9679691207673762, 4.617269581706101e-301)
(0.9879327205600071, 0.0)
(0.9843953736273147, 0.0)
(0.9801401955756244, 0.0)
(0.9906264560491775, 0.0)
(