In [405]:
from sklearn.linear_model import orthogonal_mp
import numpy as np
import os
import re
from PIL import Image

In [406]:
class image_parser(object):
    
    def __init__(self,datapath=None,folder=None,images=None,image_names=None,
                 min_rows=10,min_cols=10,verbose=False):
        self.datapath=datapath
        self.folder=folder
        self.images=images
        self.image_names=image_names
        self.min_rows=min_rows
        self.min_cols=min_cols
        self.verbose=verbose

    def read(self):
        self.images=np.empty(self.min_rows*self.min_cols)
        self.image_names=[]
        for base, dirs, files in os.walk (self.datapath+'/'+self.folder+'/'):
            for filename in files:
                if self.verbose: print("reading..."
                                  +self.datapath+'/'+self.folder+'/'+filename)
                name_JPEG = re.match (r'^(.*)\.JPEG$',filename)
                if name_JPEG:
                    filepath = os.path.join (base, filename)
                    image = Image.open (filepath,'r'
                                       ).resize((self.min_rows,self.min_cols)).convert("L")
                    image = np.array(image).reshape(-1)
                    self.images=np.vstack((self.images,image))
                    self.image_names.append(filename)

In [407]:
datapath="/Users/Heather/Desktop/gatech/spring 2016/6240 web search and text mining/project2"

In [408]:
# initialize the dictionary D with random numbers and normalize the columns.
def genDict(num):
    D=np.random.random(size=(400, num))
    for i in range(num):
        D[:,i]=D[:,i]/np.linalg.norm(D[:,i]) 
    return D

In [None]:
# following is encoding for animal: 

In [409]:
animals=image_parser(folder="Animal",datapath=datapath,
                     verbose=False,min_rows=20,min_cols=20)
animals.read()

# romove the first row which is randomly generated.
data=animals.images[1:]/255
length=len(data)

# divide the data into training data and testing data. We use Transpose to make each column a data sample
Y=data.T[:,0:int(length*0.8)]
Y_test=data.T[:,int(length*0.8):]
D=genDict(800)

# use orthogonal matching pursuit algorithm to find X
X=orthogonal_mp(D,Y)

In [410]:
# update X using orthogonal matching pursuit algorithm
def updateX(Y,D):
    return orthogonal_mp(D,Y)

In [411]:
# calculate the approximation error based on given Y, D and X
def Error(Y,D,X):
    return np.linalg.norm(Y-D.dot(X))

In [412]:
# update D using K-SVD algorithm
def updateD(Y,D,X):
    
    # initialize the new dictionary
    D_update=np.zeros(D.shape)
    
    for i in range(D.shape[1]):
        
        # from the whole dataset select data which uses the ith column in the dictionary and only use this part of data to 
        # update the ith column
        X_select=X[:,X[i,:]!=0]
        
        # if this part is empty, we update the ith column using the data which performs worst according to current dictionary
        if(X_select.shape[1]==0):
            max_norm=0
            p=0
            # error matrix:
            Error_M=Y-D.dot(X)
            # find the vector with largest norm in the error matrix
            for k in range(Error_M.shape[1]):
                norm=np.linalg.norm(Error_M[:,k])
                if norm>max_norm:
                    max_norm = norm
                    p=k
            
            # update the ith column of dictionary
            D_update[:,i]=Y[:,p]/np.linalg.norm(Y[:,p]) 
        
        # if this part isn't empty, extract the error from this part 
        else:
            Y_new=Y[:,X[i,:]!=0]
            X_select[i,:]=0
            # error matrix
            Error_M=Y_new-D.dot(X_select)
            # do SVD on error matrix
            U, s, V=np.linalg.svd(Error_M)
            # replace the ith column by the first column of U since it catch most variance in the data
            D_update[:,i]=U[:,0]
            
    return D_update

In [413]:
Error(Y,D,X)

86.111560165807916

In [414]:
for i in range(20):
    D=updateD(Y,D,X)
    X=updateX(Y,D)
    print Error(Y,D,X)

65.2337463313
53.8866971475
46.4485374423
41.6406061438
39.4341427901
38.6851488448
38.3391443761
37.7899012816
38.8612397783
39.9344566316
39.4689410998
40.8532195598
44.3826334797
45.8045774933
44.9009540147
44.7595090793
45.7664845419
47.0307400315
46.8208691099
46.3348112512


In [417]:
X_test=updateX(Y_test,D)

In [419]:
# test error:
Error(Y_test,D,X_test)

33.456659553443302

In [420]:
# following is encoding for Geological Formation: 

In [426]:
GF=image_parser(folder="Geological Formation",datapath=datapath,
                     verbose=False,min_rows=20,min_cols=20)
GF.read()
data=GF.images[1:]/255
length=len(data)
Y=data.T[:,0:int(length*0.8)]
Y_test=data.T[:,int(length*0.8):]
D=genDict(800)
X=orthogonal_mp(D,Y)

In [427]:
Error(Y,D,X)

93.680863878782318

In [428]:
for i in range(10):
    D=updateD(Y,D,X)
    X=updateX(Y,D)
    print Error(Y,D,X)

72.649544347
59.6851816651
51.2164930668
45.9464719343
43.7271564127
43.4378373467
43.2707195374
43.0392165072
44.633161527
47.0644265809


In [429]:
X_test=updateX(Y_test,D)

In [430]:
# test error:
Error(Y_test,D,X_test)

32.448493255595601

In [None]:
# following is encoding for the mixture of animal and Geological Formation: 

In [433]:
animals=image_parser(folder="Animal",datapath=datapath,
                     verbose=False,min_rows=20,min_cols=20)
animals.read()
data1=animals.images[1:]/255

# 1 stands for animals
data1=np.append(data1,np.ones((data1.shape[0],1)),1)


GF=image_parser(folder="Geological Formation",datapath=datapath,
                     verbose=False,min_rows=20,min_cols=20)
GF.read()
data2=GF.images[1:]/255

# 0 stands for geological formation
data2=np.append(data2,np.zeros((data2.shape[0],1)),1)


In [434]:
data=np.concatenate((data1,data2)).T
np.random.shuffle(data.T)
length=data.shape[1]
Y=data[:,0:int(length*0.8)]
Y_test=data[:,int(length*0.8):]
D=genDict(800)

In [435]:
X=orthogonal_mp(D,Y[:-1,:])

In [436]:
Error(Y[:-1,:],D,X)

127.79280337011205

In [437]:
for i in range(10):
    D=updateD(Y[:-1,:],D,X)
    X=updateX(Y[:-1,:],D)
    print Error(Y[:-1,:],D,X)

99.6999335181
83.9641940303
73.705866032
66.9757887887
62.6332178486
59.9866321084
58.4929990043
57.5198555972
56.5891252133
55.6043674915


In [438]:
# test error:
X_test=updateX(Y_test[:-1,:],D)
Error(Y_test[:-1,:],D,X_test)

42.4149272082769