In [82]:
%load_ext autoreload
%autoreload 2

import jointNMF
import numpy as np
from scipy import sparse
from scipy import stats
from sklearn.decomposition import NMF
from sklearn.utils.extmath import safe_sparse_dot

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# TODO items
0. take a look at smaller examples to see if it works
1. make everything sparse (use sparse matrix as much as possible)
2. for initialization run 10 times and pick the best initialization
3. put option to make mu use specified

In [83]:
class CustomRandomState(np.random.RandomState):
    def randint(self, k):
        i = np.random.randint(k)
        return i - i % 2

In [84]:
Nsamples=300000
Nfeatures=2000

np.random.seed(12345)
rs = CustomRandomState()
rvs = stats.poisson(10, loc=10).rvs

num_shared_components=10
num_healthy_components=10
num_disease_components=10

In [85]:
Wh = sparse.random(Nsamples, num_healthy_components+num_shared_components, density=0.1, random_state=rs, data_rvs=rvs).tocsr()
Hh = sparse.random(num_healthy_components+num_shared_components, Nfeatures, density=0.1, random_state=rs, data_rvs=rvs).tocsr()

Wd = sparse.random(Nsamples, num_healthy_components+num_shared_components, density=0.1, random_state=rs, data_rvs=rvs).tocsr()
Hd = sparse.random(num_healthy_components+num_shared_components, Nfeatures, density=0.1, random_state=rs, data_rvs=rvs).tocsr()

Wd[:,:num_shared_components] = Wh[:,:num_shared_components]

Xh = safe_sparse_dot(Wh, Hh) #+ np.random.randn(Nsamples, Nfeatures)*0.001
Xd = safe_sparse_dot(Wd, Hd) #+ np.random.randn(Nsamples, Nfeatures)*0.001
Xh = np.abs(Xh)
Xd = np.abs(Xd)

In [86]:
Xh

<300000x2000 sparse matrix of type '<class 'numpy.float64'>'
	with 100697287 stored elements in Compressed Sparse Row format>

In [87]:
model = jointNMF.JointNMF(Xh, Xd, gamma=50, mu=0.1)

In [None]:
model.solve(maxiters=200)

In [67]:
np.std(model.Wh[:,0].todense()), np.mean(model.Wh[:,0].todense()), 
np.std(model.Wd[:,0].todense()), np.mean(model.Wd[:,0].todense()),
np.asarray(model.Wh[:,0].todense()), np.asarray(model.Wd[:,0].todense())

(array([[0.        ],
        [0.        ],
        [0.07082627],
        ...,
        [0.        ],
        [0.        ],
        [0.        ]]), array([[6.48383461e-02],
        [0.00000000e+00],
        [0.00000000e+00],
        ...,
        [0.00000000e+00],
        [9.36470054e-07],
        [0.00000000e+00]]))

In [52]:
for i in range(3):
    for j in range(3):
        print(i, j, np.corrcoef(model.Wh[:,i].todense()+0.01, model.Wd[:,j].todense()+0.01)[0][1])

0 0 nan
0 1 nan
0 2 nan
1 0 nan
1 1 nan
1 2 nan
2 0 nan
2 1 nan
2 2 nan


In [28]:
for i in range(Wh.shape[1]):
    for j in range(Wh.shape[1]):
        print(i, j, np.corrcoef(model.Wh[:,i], Wh[:,j])[0][1])

0 0 -0.23919132301506876
0 1 -0.020329040533452768
0 2 0.05653547517866928
0 3 0.7944151924840428
0 4 0.34248640161577865
0 5 0.06863453136166481
0 6 0.09027510965114953
0 7 0.23728392716716348
0 8 0.19073355731960454
0 9 -0.05380210378835711
0 10 0.12658510989860022
0 11 0.034269583051979774
0 12 0.055832873995750855
0 13 0.06462120263122763
0 14 -0.09170893271200274
1 0 0.714191716555186
1 1 -0.05942276686801498
1 2 0.13462368519402632
1 3 -0.32299343900889466
1 4 -0.139036975170328
1 5 0.13276039454011296
1 6 0.08736278352387411
1 7 0.00895874870308073
1 8 -0.0835679498092433
1 9 0.18635912045313527
1 10 0.052288711366931646
1 11 0.3267588955253579
1 12 0.005933465095598795
1 13 0.28128691765964003
1 14 -0.16507769329646108
2 0 -0.36413998014793797
2 1 -0.07842703341795322
2 2 0.4862088135104401
2 3 -0.22151938323855214
2 4 0.09448924431447528
2 5 -0.028981374427506076
2 6 0.11131173275188662
2 7 -0.46157264008495474
2 8 0.07860933431543633
2 9 0.15294614801585432
2 10 0.38002232876

In [29]:
for i in range(Wd.shape[1]):
    for j in range(Wd.shape[1]):
        print(i, j, np.corrcoef(model.Wd[:,i], model.Wd[:,j])[0][1])

0 0 1.0
0 1 -0.3357900902708881
0 2 0.03543285551238617
0 3 -0.12140019101452496
0 4 -0.01429777980126236
0 5 -0.02818866283581724
0 6 -0.11070088802267132
0 7 -0.0674341603992078
0 8 0.12275439039427045
0 9 0.22054903598544615
0 10 0.3503140655883547
0 11 -0.08394550285640004
0 12 0.10001908722539606
0 13 -0.028116253036382658
0 14 0.017296721392212132
1 0 -0.33579009027088813
1 1 1.0
1 2 -0.0476116822971609
1 3 0.3666066489777668
1 4 0.012256976388125557
1 5 0.14044778341104175
1 6 0.03202509496453798
1 7 -0.2150334029041851
1 8 -0.28867119649509165
1 9 -0.08520694770439172
1 10 -0.09930831946244728
1 11 0.10049720791944859
1 12 0.03882600021825522
1 13 -0.11239809278739368
1 14 -0.19378275916554474
2 0 0.035432855512386174
2 1 -0.04761168229716091
2 2 0.9999999999999999
2 3 0.13537416688967918
2 4 -0.08860545148313727
2 5 -0.09928990758058849
2 6 -0.08186854089620961
2 7 0.09679817065718208
2 8 0.04292637382169394
2 9 -0.16970467248207585
2 10 -0.12644936005456459
2 11 0.05750750090