観測行列を$\bf{Y}\in\bf{R}_+^{m\times n}$, 辞書行列を$\bf{H}\in\bf{R}_+^{m\times k}$, 係数行列を$\bf{U}\in\bf{R}_+^{k\times n}$とする。ただし$k\leq m$とする。

以下のクラスは観測行列$\bf{Y}$と辞書の要素数$k$を受け取り、$$\bf{Y}\approx\bf{H}\bf{U}$$となる非負の辞書行列$\bf{H}$と係数行列$\bf{U}$を乗法的更新則によって求めるNMF(非負値行列因子分解)のクラスである。

In [129]:
import numpy as np
import math as math

class NMF(object):
    """Non-negative matrix factorisation ``Y ~= H U`` by multiplicative updates.

    ``Y`` is the (m, n) observation matrix, ``H`` the (m, k) dictionary and
    ``U`` the (k, n) coefficient matrix.  ``Y`` is expected to be
    non-negative; this is not checked.

    Parameters
    ----------
    Y : array-like, 2-dimensional
        Observation matrix.  It is converted to a float ``ndarray``.
    k : int
        Number of dictionary atoms.  Must satisfy ``0 <= k <= n`` because the
        dictionary is initialised from the first ``k`` columns of ``Y``.

    Raises
    ------
    ValueError
        If ``Y`` is not 2-dimensional or ``k`` is outside ``[0, n]``.
    """

    # Added to every denominator in ``update`` so that an all-zero row or
    # column yields 0 instead of 0/0 = NaN.
    _EPS = 1e-12

    def __init__(self, Y, k):
        Y = np.asarray(Y, dtype=float)
        if Y.ndim != 2:
            raise ValueError('Matrix size error (shold be 2-d array)')
        self.m, self.n = Y.shape
        if not 0 <= k <= self.n:
            raise ValueError(
                'k must satisfy 0 <= k <= n; got k=%r, n=%r' % (k, self.n))
        self.Y = Y
        self.k = k
        self.__initialize()

    def update(self):
        """Run one multiplicative update of the dictionary, then the coefficients.

        ``H`` is updated first using the current ``U``; ``U`` is then updated
        using the freshly updated ``H``.
        """
        Y = self.Y
        U = self.U
        eps = self._EPS
        H = self.H * Y.dot(U.T) / (self.H.dot(U.dot(U.T)) + eps)
        self.H = H
        self.U = U * H.T.dot(Y) / (H.T.dot(H).dot(U) + eps)

    def __initialize(self):
        """Seed ``H`` with the first ``k`` columns of ``Y`` and ``U`` with uniform noise."""
        self.H = self.Y[:, :self.k].copy()
        self.U = np.random.random_sample((self.k, self.n))

    def get_dic(self):
        """Return the current dictionary matrix ``H``, shape (m, k)."""
        return self.H

    def get_coef(self):
        """Return the current coefficient matrix ``U``, shape (k, n)."""
        return self.U

    def error(self):
        """Return the Frobenius norm of the residual ``Y - H U`` as a float."""
        return float(np.linalg.norm(self.Y - self.H.dot(self.U)))
    
 

下記の関数は、$\bf{U}$を各列に丁度$n\_sparse$個の非零要素を持つスパース行列、$\bf{H}$を乱数行列として$\bf{Y}=\bf{H}\bf{U}$を生成し、その$\bf{Y}$に対してNMFを実行して誤差と辞書の復元率を評価する。$\bf{Y}$はスパースな成分に分解できる行列となる。

In [33]:
# Utility function for NMF evaluation.
def NMF_eval(m=20, n=1500, k=15, n_sparse=3, n_iter=10000, verbose=True):
    """Run NMF on synthetic sparse data and track error and atom recovery.

    A random (m, k) dictionary and a (k, n) coefficient matrix with exactly
    ``n_sparse`` randomly placed uniform entries per column are multiplied to
    form ``Y``.  ``NMF(Y, k)`` is then updated ``n_iter`` times and evaluated
    after every update.

    Parameters
    ----------
    m, n : int
        Shape of the generated observation matrix ``Y``.
    k : int
        Number of dictionary atoms.
    n_sparse : int
        Number of randomly drawn entries in each column of the true
        coefficient matrix; must satisfy ``0 <= n_sparse <= k``.
    n_iter : int
        Number of NMF updates to run.
    verbose : bool
        If true, print the matrix shapes and one ``error atom`` line per
        iteration.

    Returns
    -------
    (list, list)
        Per-iteration squared reconstruction error and per-iteration
        percentage of recovered atoms.

    Raises
    ------
    ValueError
        If ``n_sparse`` is outside ``[0, k]``.
    """

    def create_sparse(k, n, n_sparse):
        """Return a (k, n) matrix with ``n_sparse`` random entries per column."""
        if not 0 <= n_sparse <= k:
            raise ValueError(
                'n_sparse must satisfy 0 <= n_sparse <= k; got %r and %r'
                % (n_sparse, k))
        X = np.zeros((k, n))
        X[:n_sparse] = np.random.random_sample((n_sparse, n))
        # Shuffle every column independently so the drawn entries land on
        # random rows.
        shuffled = [np.random.permutation(col) for col in X.T]
        return np.transpose(shuffled)

    def create_data(m, n, k, n_sparse):
        """Return ``(Y, H, U)`` with ``Y = H U``, ``H`` random and ``U`` sparse."""
        U = create_sparse(k, n, n_sparse)   # (k, n)
        H = np.random.random_sample((m, k))  # (m, k)
        return H.dot(U), H, U

    def unit_norm(A, axis):
        """Scale ``A`` to unit L2 norm along ``axis``; zero vectors stay zero."""
        norms = np.linalg.norm(A, axis=axis, keepdims=True)
        norms[norms == 0] = 1.0
        return A / norms

    def count_atoms(_est_A, _true_A, axis=0, tol=0.01):
        """Return the percentage of estimated atoms matching a true atom.

        Atoms are the columns of the matrices.  Each estimated atom is paired
        with the true atom nearest in mean squared distance and counts as
        recovered when ``1 - |<true, est>| < tol``.  Several estimated atoms
        may be paired with the same true atom.

        Parameters
        ----------
        _est_A : array, shape (n_features, n_atoms)
            Estimated dictionary.
        _true_A : array, shape (n_features, n_atoms)
            True dictionary.
        axis : 0 or 1, optional (0 by default)
            Axis along which vectors are scaled to unit L2 norm; 0 scales
            each column, 1 scales each row.
        tol : float, optional
            Recovery threshold.

        Raises
        ------
        ValueError
            If the two matrices differ in shape.
        """
        if not (_est_A.shape == _true_A.shape):
            raise ValueError("The shape of dictionaries should be same;"
                             "got %r and %r " % (_est_A.shape, _true_A.shape))
        num_atoms = _true_A.shape[1]

        est_A = unit_norm(_est_A, axis)
        true_A = unit_norm(_true_A, axis)

        # distances[i, j]: mean squared difference between estimated atom i
        # and true atom j.
        gaps = est_A.T[:, None, :] - true_A.T[None, :, :]
        distances = np.mean(gaps ** 2, axis=2)
        nearest = true_A[:, np.argmin(distances, axis=1)]

        # For unit-norm atoms the inner product is the cosine similarity.
        dis_t = 1 - np.abs(np.sum(nearest * est_A, axis=0))
        num_recovered = int(np.count_nonzero(dis_t < tol))

        return 100 * (num_recovered / num_atoms)

    def evaluate(Y, H, U, trueDic):
        """Return ``(squared reconstruction error, recovered-atom percentage)``."""
        error = np.sum((Y - np.dot(H, U)) ** 2)
        atom = count_atoms(H, trueDic)
        return error, atom

    errors = []
    atoms = []
    Y, trueH, trueU = create_data(m, n, k, n_sparse)
    res = NMF(Y, k)
    if verbose:
        print(np.shape(Y), np.shape(trueH), np.shape(trueU))
        print('test')

    for _ in range(n_iter):
        res.update()
        err, atm = evaluate(Y, res.get_dic(), res.get_coef(), trueH)
        errors.append(err)
        atoms.append(atm)
        if verbose:
            print(err, atm)

    return errors, atoms

In [34]:
# Run the evaluation with the default settings (m=20, n=1500, k=15, n_sparse=3).
NMF_eval()

(20, 1500) (20, 15) (15, 1500)
test
2038.09591618 0.0
1993.3084236 0.0
1950.00708558 0.0
1905.10614802 0.0
1856.92907638 0.0
1804.13189025 0.0
1745.8453737 0.0
1681.88903475 0.0
1612.94706247 0.0
1540.52904475 0.0
1466.61252541 0.0
1393.10973529 0.0
1321.46535901 0.0
1252.55794585 0.0
1186.82663389 0.0
1124.45090917 0.0
1065.47744741 0.0
1009.87899582 0.0
957.572705317 0.0
908.425761594 0.0
862.26294127 0.0
818.879885718 0.0
778.060380948 0.0
739.594102593 0.0
703.291363102 0.0
668.992399216 0.0
636.570205031 0.0
605.92750929 0.0
576.989781962 0.0
549.69671667 0.0
523.994344759 0.0
499.829074811 0.0
477.144010439 0.0
455.877254809 0.0
435.961657068 0.0
417.325478195 0.0
399.893584777 0.0
383.588904484 0.0
368.333955984 0.0
354.052307852 0.0
340.669851411 0.0
328.115809639 0.0
316.323450074 0.0
305.230514435 0.0
294.779410151 0.0
284.917223476 0.0
275.595612338 0.0
266.77062548 0.0
258.402479497 0.0
250.455311769 0.0
242.896917385 0.0
235.698472566 0.0
228.834245336 0.0
222.28129528 0.0

17.7251936682 0.0
17.7015184162 0.0
17.6779583807 0.0
17.6545127925 0.0
17.6311808799 0.0
17.607961869 0.0
17.5848549841 0.0
17.5618594484 0.0
17.5389744836 0.0
17.5161993112 0.0
17.4935331527 0.0
17.47097523 0.0
17.4485247661 0.0
17.4261809857 0.0
17.4039431156 0.0
17.3818103853 0.0
17.3597820277 0.0
17.3378572796 0.0
17.3160353819 0.0
17.2943155807 0.0
17.272697127 0.0
17.251179278 0.0
17.229761297 0.0
17.2084424537 0.0
17.1872220251 0.0
17.1660992952 0.0
17.1450735557 0.0
17.1241441062 0.0
17.1033102545 0.0
17.0825713164 0.0
17.0619266163 0.0
17.0413754873 0.0
17.0209172712 0.0
17.0005513183 0.0
16.980276988 0.0
16.9600936486 0.0
16.940000677 0.0
16.919997459 0.0
16.9000833893 0.0
16.880257871 0.0
16.8605203161 0.0
16.8408701447 0.0
16.8213067855 0.0
16.801829675 0.0
16.782438258 0.0
16.7631319867 0.0
16.7439103211 0.0
16.7247727281 0.0
16.7057186821 0.0
16.6867476637 0.0
16.6678591605 0.0
16.6490526659 0.0
16.6303276793 0.0
16.6116837057 0.0
16.5931202552 0.0
16.5746368429 0.0
16.5

12.4235553134 0.0
12.4178683515 0.0
12.4121905405 0.0
12.4065218326 0.0
12.4008621815 0.0
12.3952115422 0.0
12.3895698712 0.0
12.3839371261 0.0
12.3783132659 0.0
12.372698251 0.0
12.3670920432 0.0
12.3614946053 0.0
12.3559059017 0.0
12.3503258979 0.0
12.3447545608 0.0
12.3391918586 0.0
12.3336377605 0.0
12.3280922372 0.0
12.3225552605 0.0
12.3170268036 0.0
12.3115068406 0.0
12.305995347 0.0
12.3004922995 0.0
12.2949976757 0.0
12.2895114546 0.0
12.2840336161 0.0
12.2785641413 0.0
12.2731030124 0.0
12.2676502124 0.0
12.2622057256 0.0
12.256769537 0.0
12.2513416329 0.0
12.2459220001 0.0
12.2405106266 0.0
12.235107501 0.0
12.2297126129 0.0
12.2243259525 0.0
12.2189475108 0.0
12.2135772793 0.0
12.2082152504 0.0
12.2028614168 0.0
12.1975157718 0.0
12.1921783092 0.0
12.186849023 0.0
12.1815279078 0.0
12.1762149582 0.0
12.1709101691 0.0
12.1656135355 0.0
12.1603250526 0.0
12.1550447155 0.0
12.1497725191 0.0
12.1445084584 0.0
12.1392525281 0.0
12.1340047224 0.0
12.1287650355 0.0
12.123533461 0.

KeyboardInterrupt: 