In [292]:
import jointNMF
import numpy as np
import scipy
from scipy import sparse as sp
from scipy import stats
from sklearn.decomposition import NMF
from sklearn.utils.extmath import safe_sparse_dot

In [293]:
class CustomRandomState(np.random.RandomState):
    def randint(self, k):
        i = np.random.randint(k)
        return i - i % 2

In [294]:
Nsamples=6000
Nfeatures=2000

np.random.seed(12345)
rs = CustomRandomState()
rvs = stats.poisson(10, loc=3).rvs

num_shared_components=5
num_healthy_components=10
num_disease_components=10

In [295]:
Wh = sparse.random(Nsamples, num_healthy_components+num_shared_components, density=0.1, random_state=rs, data_rvs=rvs).tocsr()
Hh = sparse.random(num_healthy_components+num_shared_components, Nfeatures, density=0.1, random_state=rs, data_rvs=rvs).tocsr()
Hh = Hh/np.max(Hh)

Wd = sparse.random(Nsamples, num_healthy_components+num_shared_components, density=0.1, random_state=rs, data_rvs=rvs).tocsr()
Wd = Wd/np.max(Wd)
Hd = sparse.random(num_healthy_components+num_shared_components, Nfeatures, density=0.1, random_state=rs, data_rvs=rvs).tocsr()
Hd = Hd/np.max(Hd)

Wd[:,:num_shared_components] = Wh[:,:num_shared_components]

Xh = safe_sparse_dot(Wh, Hh) #+ np.random.randn(Nsamples, Nfeatures)*0.001
Xd = safe_sparse_dot(Wd, Hd) #+ np.random.randn(Nsamples, Nfeatures)*0.001
Xh = np.abs(Xh)
Xd = np.abs(Xd)

In [296]:
Wd.shape, Wh.shape, Hh.shape, Hd.shape

((6000, 15), (6000, 15), (15, 2000), (15, 2000))

In [297]:
model = jointNMF.JointNMF(Xh, Xd, gamma=50, mu=0.1, 
                          nsh_components=num_shared_components, 
                          nh_components=num_healthy_components,
                          nd_components=num_disease_components)

In [298]:
model.Wh.toarray()

array([[2.39814340e-01, 0.00000000e+00, 0.00000000e+00, ...,
        3.70929313e-06, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [2.19828051e-01, 0.00000000e+00, 0.00000000e+00, ...,
        4.36837605e-01, 0.00000000e+00, 5.62263197e-07],
       ...,
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
        4.58523046e-07, 3.83303425e-01, 5.33405393e-01],
       [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00]])

In [299]:
nsh_components=num_shared_components
nh_components=num_healthy_components
nd_components=num_disease_components
nmfh = NMF(n_components = nsh_components + nh_components)
nmfd = NMF(n_components = nsh_components + nd_components)

model.Wh = scipy.sparse.csr_matrix(nmfh.fit_transform(Xh))
model.Hh = scipy.sparse.csr_matrix(nmfh.components_)

model.Wd = scipy.sparse.csr_matrix(nmfd.fit_transform(Xd))
model.Hd = scipy.sparse.csr_matrix(nmfd.components_)

model.maxiters = 1000
model.tol = 1E-06
model.gamma = 50
model.mu=0.5
model.Xh = scipy.sparse.csr_matrix(Xh).copy()
model.Xd = scipy.sparse.csr_matrix(Xd).copy()

In [300]:
model.Wd, Xh

(<6000x15 sparse matrix of type '<class 'numpy.float64'>'
 	with 15985 stored elements in Compressed Sparse Row format>,
 <6000x2000 sparse matrix of type '<class 'numpy.float64'>'
 	with 1581555 stored elements in Compressed Sparse Row format>)

In [301]:
scale2=np.append((model.gamma+model.mu)*np.ones(nsh_components), (model.mu)*np.ones(nh_components))
scale2
model.Wshd = model.Wd[:,:nsh_components]
Wh_up1 = safe_sparse_dot(model.Xh, model.Hh.T)

In [302]:
scale2=np.append((model.gamma+model.mu)*np.ones(nsh_components), (model.mu)*np.ones(nh_components))
model.Wshd = model.Wd[:,:nsh_components]
Wh_up1 = safe_sparse_dot(model.Xh, model.Hh.T)
Wshd_transform = model.Wshd.multiply(scipy.sparse.csr_matrix(model.gamma*np.ones((model.Xh.shape[0], nsh_components))))
zeros = scipy.sparse.csr_matrix(np.zeros((model.Xh.shape[0], nh_components)))
Wh_up2 = scipy.sparse.hstack((Wshd_transform, zeros)) #+ _smallnumber2
Wh_down = safe_sparse_dot(model.Wh, safe_sparse_dot(model.Hh, model.Hh.T)) + safe_sparse_dot(model.Wh, np.diag(scale2)) #+ _smallnumber2
Wh_temp = model.Wh.multiply((Wh_up1 + Wh_up2)/Wh_down).tocsr() 

In [303]:
_smallnumber2 = 1E-06
model.Wd.array = model.Wd.toarray()
model.Wh.array = model.Wh.toarray()
model.Xh.array = model.Xh.toarray()
model.Hh.array = model.Hh.toarray()

model.Wshd.array = model.Wd.array[:,:nsh_components]
Wh_up = np.dot(model.Xh.array, model.Hh.array.T) + \
        np.append(model.Wshd.array * (model.gamma*np.ones((model.Xh.array.shape[0], nsh_components))), \
        np.zeros((model.Xh.array.shape[0], nh_components)), 1) + _smallnumber2 
Wh_down = np.dot(model.Wh.array, np.dot(model.Hh.array, model.Hh.array.T)) + \
            np.dot(model.Wh.array, np.diag(scale2)) + _smallnumber2 
Wh_temp2 = model.Wh.array*(Wh_up/Wh_down)

In [304]:
Wh_temp.toarray(), Wh_temp2

(array([[1.98490694e-01, 0.00000000e+00, 0.00000000e+00, ...,
         3.70923270e-06, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.93082957e-01, 0.00000000e+00, 0.00000000e+00, ...,
         4.35242504e-01, 0.00000000e+00, 5.62261572e-07],
        ...,
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         4.58521174e-07, 3.81476143e-01, 5.29900571e-01],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 array([[1.98490695e-01, 0.00000000e+00, 0.00000000e+00, ...,
         3.70923270e-06, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [1.93082957e-01, 

In [305]:
np.max(Wh_temp2 - Wh_temp), np.min(Wh_temp2 - Wh_temp)

(3.3668366428152297e-09, -5.130659119778613e-09)

In [306]:
Hh_up = safe_sparse_dot(model.Wh.T, model.Xh) #+ _smallnumber2
Hh_down = safe_sparse_dot(safe_sparse_dot(model.Wh.T, model.Wh), model.Hh) #+ _smallnumber2
Hh_temp = model.Hh.multiply(Hh_up/Hh_down).tocsr()

In [307]:
Hh_up = np.dot(model.Wh.array.T, model.Xh.array) + _smallnumber2 
Hh_down = np.dot(np.dot(model.Wh.array.T, model.Wh.array), model.Hh.array)  + _smallnumber2 
Hh_temp2 = model.Hh.array*(Hh_up/Hh_down)

In [308]:
np.max(Hh_temp2 - Hh_temp), np.min(Hh_temp2 - Hh_temp)

(6.672440377997191e-14, -1.1313172620930345e-13)

In [313]:
diff1 = 0.5*sparse.linalg.norm(model.Xh - safe_sparse_dot(model.Wh, model.Hh), ord='fro')**2
diff2 = 0.5*sparse.linalg.norm(model.Xd - safe_sparse_dot(model.Wd, model.Hd), ord='fro')**2
diff3 = (model.mu/2)*(sparse.linalg.norm(model.Wh, ord='fro')**2) + (model.mu/2)*(sparse.linalg.norm(model.Wd, ord='fro')**2)
model.Wshh = model.Wh[:,:nsh_components]
model.Wshd = model.Wd[:,:nsh_components]
diff4 = (model.gamma/2)*sparse.linalg.norm(model.Wshh-model.Wshd, ord='fro')**2
chi2 = diff1 + diff2 + diff3 + diff4
chi2, diff1, diff2, diff3, diff4, ((model.mu/2)*sparse.linalg.norm(model.Wh, ord='fro')**2), (model.mu/2)*(sparse.linalg.norm(model.Wd, ord='fro'))

(16936.13700859847,
 7.577304680324081e-06,
 5.840001328046725e-06,
 711.2542196709505,
 16224.882775510216,
 363.7036678711697,
 9.321353868936916)

In [314]:
diff1 = 0.5*np.linalg.norm(model.Xh - np.dot(model.Wh.toarray(), model.Hh.toarray()), ord='fro')**2
diff2 = 0.5*np.linalg.norm(model.Xd - np.dot(model.Wd.toarray(), model.Hd.toarray()), ord='fro')**2
diff3 = (model.mu/2)*(np.linalg.norm(model.Wh.toarray(), ord='fro')**2) + (model.mu/2)*(np.linalg.norm(model.Wd.toarray(), ord='fro')**2)
diff4 = model.gamma/2*np.linalg.norm(model.Wshh.toarray()-model.Wshd.toarray(), ord='fro')**2
chi2 = diff1 + diff2 + diff3 + diff4 
chi2, diff1, diff2, diff3, diff4, (model.mu/2)*(np.linalg.norm(model.Wh.toarray(), ord='fro')**2), (model.mu/2)*(np.linalg.norm(model.Wd.toarray(), ord='fro')**2)

(16936.13700859847,
 7.577304680324254e-06,
 5.840001328046614e-06,
 711.2542196709508,
 16224.882775510216,
 363.7036678711699,
 347.5505517997808)