Permalink
Browse files

Add a pLSA implementation.

  • Loading branch information...
1 parent 6593fa6 commit d6eebcd2e6ba6f7528ed0d6491d98434aac2026c @tomtung tomtung committed Oct 19, 2011
Showing with 138 additions and 0 deletions.
  1. +4 −0 MATLAB/pLSA/README
  2. BIN MATLAB/pLSA/data.7z
  3. +66 −0 MATLAB/pLSA/pLSA.m
  4. +32 −0 MATLAB/pLSA/prepare.m
  5. +15 −0 MATLAB/pLSA/run.m
  6. +21 −0 MATLAB/pLSA/show_result.m
View
@@ -0,0 +1,4 @@
+Extract data.txt from data.7z, and execute the run script.
+
+Notes on this algorithm (written in Chinese):
+http://blog.tomtung.com/2011/10/plsa/
View
Binary file not shown.
View
@@ -0,0 +1,66 @@
% Given the number of occurrences of word w in document d, n_dw(d,w), and the
% number of topics to discover, n_z, this function fits a pLSA model via EM.
%
% Inputs:
%   n_dw - (sparse) document-term count matrix
%   n_z  - number of latent topics to discover
% Outputs:
%   p_w_given_z - n_w-by-n_z matrix; column z holds p(w|z)
%   p_z_given_d - n_z-by-n_d matrix; column d holds p(z|d)
function [p_w_given_z, p_z_given_d] = pLSA(n_dw, n_z)

% Filter out words that are too common or too obscure: drop a word whose
% document frequency exceeds 1.5/n_z of all documents, or is at most
% 1/(10*n_z) of them (heuristic thresholds from the original author).
for w = 1:size(n_dw,2)
    df = nnz(n_dw(:,w)); % document frequency of word w
    if df > size(n_dw,1)*1.5/n_z || df <= size(n_dw,1)/(n_z*10)
        n_dw(:,w) = 0;
    end
end

% Pre-allocate space and randomly initialize the model.
[n_d, n_w] = size(n_dw);       % max indices of d and w
p_z_given_d = rand(n_z, n_d);  % p(z|d)
p_w_given_z = rand(n_w, n_z);  % p(w|z)
n_p_z_given_dw = cell(n_z, 1); % n(d,w) * p(z|d,w), one sparse matrix per z
for z = 1:n_z
    % random values on the same sparsity pattern as n_dw
    n_p_z_given_dw{z} = sprand(n_dw);
end

p_dw = sprand(n_dw); % p(d,w), random initial values, same sparsity pattern

% EM iterations until the relative log-likelihood improvement drops below
% 0.1%. Use double sentinels rather than intmin/intmax: mixing int32 with
% the double log-likelihood would silently saturate in int32 arithmetic.
L = -realmax;        % log-likelihood
improvement = inf;   % used to determine convergence
while improvement > 0.001*abs(L)
    disp('E-step');
    % n(d,w) * p(z|d,w) = n(d,w) * p(z|d) * p(w|z) / p(d,w),
    % evaluated only where n(d,w) is nonzero.
    for d = 1:n_d
        for w = find(n_dw(d,:))
            for z = 1:n_z
                n_p_z_given_dw{z}(d,w) = p_z_given_d(z,d) * p_w_given_z(w,z) * n_dw(d,w) / p_dw(d, w);
            end
        end
    end

    disp('M-step');
    disp('update p(z|d)')
    % p(z|d) is proportional to sum_w n(d,w)*p(z|d,w); normalize over z.
    concat = cat(2, n_p_z_given_dw{:}); % make summing a whole row d at once possible
    for d = 1:n_d
        for z = 1:n_z
            p_z_given_d(z,d) = sum(n_p_z_given_dw{z}(d,:));
        end
        p_z_given_d(:,d) = p_z_given_d(:,d) / sum(concat(d,:));
    end

    disp('update p(w|z)')
    % p(w|z) is proportional to sum_d n(d,w)*p(z|d,w); normalize over w.
    for z = 1:n_z
        for w = 1:n_w
            p_w_given_z(w,z) = sum(n_p_z_given_dw{z}(:,w));
        end
        p_w_given_z(:,z) = p_w_given_z(:,z) / sum(n_p_z_given_dw{z}(:));
    end

    % Update p(d,w) = sum_z p(w|z)*p(z|d) and recompute the log-likelihood
    % L = sum_{d,w} n(d,w)*log p(d,w) to determine convergence.
    prevL = L; L = 0;
    for d = 1:n_d
        for w = find(n_dw(d,:))
            p_dw(d,w) = 0;
            for z = 1:n_z
                p_dw(d,w) = p_dw(d,w) + p_w_given_z(w,z) * p_z_given_d(z,d);
            end
            L = L + n_dw(d,w) * log(p_dw(d, w));
        end
    end
    improvement = L - prevL;

    fprintf('L = %f, improvement = %f%%\n', L, improvement/abs(L)*100);
end
disp('Improved less than 0.1%. Algorithm converged.')
View
@@ -0,0 +1,32 @@
% This function reads data, prepares it, and saves it to prepared_data.mat.
%
% Input:
%   data - path to a text file with one document per line
%
% Variables saved to prepared_data.mat:
%   word2Index - containers.Map from word string to its column index in n_dw
%   words      - cell array mapping a column index back to its word string
%   n_dw       - sparse document-term count matrix; n_dw(d,w) is the number
%                of occurrences of word w in document d
function [] = prepare(data)
word2Index = containers.Map;
words = cell(50000, 1);
% Generously over-allocated; trimmed down to the observed sizes after reading.
n_dw = spalloc(1000000, 500000, 50000000);
nDoc = 0;

% Read data and get prepared: each line of the file is one document.
fid = fopen(data);
while ~feof(fid)
    nDoc = nDoc + 1;
    fprintf('Loading document %d\n', nDoc)
    % Lower-case the line, then collapse every run of characters that are
    % neither word characters nor spaces into a single space.
    line = regexprep(lower(fgetl(fid)), '[^\w ]+', ' ');
    wordsInLine = textscan(line, '%s'); % split on whitespace
    for i = 1:size(wordsInLine{1}, 1), word = wordsInLine{1}{i};
        % Skip short words (< 4 characters), e.g. most stop words.
        if length(word) < 4, continue; end
        if ~isKey(word2Index, word)
            % First sighting: assign the next free column index.
            % NOTE(review): relies on size(containers.Map, 1) returning the
            % number of stored pairs — confirm; word2Index.Count (or
            % length(word2Index)) would be the unambiguous spelling.
            w = size(word2Index, 1) + 1;
            word2Index(word) = w;
            words{w} = word;
        else
            w = word2Index(word);
        end
        n_dw(nDoc, w) = n_dw(nDoc, w) + 1;
    end
end
fclose(fid);

% Trim the over-allocated matrix and cell array to the sizes actually used.
% NOTE(review): 1:size(word2Index) indexes with the first element of the
% size vector — verify this matches the intended vocabulary count.
n_dw = n_dw(1:nDoc,1:size(word2Index));
words = words(1:size(word2Index));

save prepared_data.mat word2Index words n_dw
View
@@ -0,0 +1,15 @@
% This script prepares the data, trains the model, saves and prints the result

% Build prepared_data.mat from the raw text file on the first run only.
if exist('prepared_data.mat', 'file') == 0
    disp('Preparing data ...')
    data = 'data.txt';
    prepare(data);
end
load prepared_data.mat

% Train the model, then persist everything needed to display the result.
disp('Training pLSA model ...')
n_z = 15; % number of topics to discover
[p_w_given_z, p_z_given_d] = pLSA(n_dw, n_z);
save result.mat word2Index words n_dw n_z p_w_given_z p_z_given_d

show_result
@@ -0,0 +1,21 @@
% This script prints the result saved in result.mat
load result.mat

% For each topic, list its highest-probability words.
n_kw = 10; % find 10 keywords for each topic
for z = 1:n_z
    fprintf('Key words for topic %d:\n', z);
    [sortedP, order] = sort(p_w_given_z(:,z), 'descend');
    topWords = order(1:n_kw)';
    for w = topWords
        fprintf('%d %s\t(%f)\n', w, words{w}, p_w_given_z(w,z))
    end
    fprintf('\n')
end
fprintf('\n')

% Print the topic mixture p(z|d) for a random sample of documents.
n_d_show = 10; % show p(z|d) for 10 randomly sampled documents
docSample = sort(randsample(size(n_dw,1), n_d_show))';
for d = docSample
    fprintf('Topic weights for document %d:\n', d);
    for z = 1:n_z
        fprintf('%f\t', p_z_given_d(z,d))
    end
    fprintf('\n')
end

0 comments on commit d6eebcd

Please sign in to comment.