Improvements to fulltest and README
emstoudenmire committed Jul 31, 2017
1 parent 86b6b63 commit 6e9badd
Showing 2 changed files with 32 additions and 3 deletions.
19 changes: 17 additions & 2 deletions README.md
@@ -3,12 +3,12 @@
Codes based on the paper "Supervised Learning with Quantum-Inspired Tensor Networks"
by Miles Stoudenmire and David Schwab. http://arxiv.org/abs/1605.05775

Also see "Tensor Train Polynomial Models via Riemannian Optimization" by Novikov, Trofimov, and Oseledets for a closely related approach: http://arxiv.org/abs/1605.03795
Also see "Tensor Train Polynomial Models via Riemannian Optimization" by Novikov, Trofimov, and Oseledets for a similar approach: http://arxiv.org/abs/1605.03795


# Code Overview

`fixedL` -- optimize a matrix product state (MPS) with a label index on the central tensor, similar to what is described in the paper arxiv:1605.05775. This MPS parameterizes a model whose output is a vector of 10 numbers (for the case of MNIST). The output entry with the largest value is the predicted label.
`fixedL` -- optimize a matrix product state (MPS) with a label index on the central tensor, similar to what is described in the paper arxiv:1605.05775, but where the label index stays fixed on the central tensor and does not move around during optimization. This MPS parameterizes a model whose output is a vector of 10 numbers (for the case of MNIST). The output entry with the largest value is the predicted label.

`fulltest` -- given an MPS ("wavefunction") generated by the fixedL program, report classification error for the MNIST testing set

@@ -78,6 +78,21 @@ Other code features:
- The code writes out a file called "sites" shortly after it begins. This holds what ITensor calls a "SiteSet": a set of common reference indices that lets all of the MPS tensors the code creates share the same site indices.
- If the code finds the file "WRITE_WF" (this file can be empty: create it with the command `touch WRITE_WF`), then after optimizing the current bond, the code will write the weight tensor MPS to the file "W" (overwriting it if already present). Once this happens, the code will delete the file "WRITE_WF".
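As a sketch, the checkpointing workflow described above looks like this from the shell (the file names "WRITE_WF" and "W" come from the description; a fixedL job is assumed to be running in the same directory):

```shell
# Request a checkpoint of the weight-tensor MPS from a running fixedL job:
touch WRITE_WF
# After finishing its current bond optimization, fixedL writes the MPS
# to the file "W" (overwriting any existing "W") and then deletes
# WRITE_WF itself, so a fresh `touch` is needed for each checkpoint.
```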

Tips for running fixedL:
- Getting a good initial weight MPS is important before spending a lot of compute time
optimizing over the full training set. A simple way to get a decent initial MPS is
to do some sweeps with a small maxm setting (say maxm = 10) and with Ntrain set very
low, such as 100. Then do some sweeps with Ntrain=1000 and finally Ntrain=10000, which
uses the full training set (if Ntrain is larger than the number of images per label,
the code just includes every training image).
- Do some sweeps at a smaller maxm or larger cutoff, which keeps the typical MPS bond
dimension low, before doing the last few sweeps with a larger maxm and smaller cutoff.
- Another really excellent trick for initialization is described in the appendix of the
Novikov et al. paper (arxiv:1605.03795). Train a linear classifier, then define an MPS
that gives a model with the same output as the linear classifier. This trick
can be extended to the single-MPS multi-task case that fixedL uses; this is left
as an exercise to the reader, but I may include code for this here eventually.
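The staged schedule in the tips above might look like the following in an ITensor-style input file. This is only a sketch: the parameter names maxm, cutoff, and Ntrain are taken from the tips, while the group name, nsweeps, and the cutoff values shown are assumptions.

```
input
    {
    Ntrain = 100    //first pass: small subset of training images per label
    maxm = 10       //keep the MPS bond dimension small at first
    cutoff = 1E-4   //relatively loose truncation early on
    nsweeps = 2
    }
```

For later passes, raise Ntrain to 1000 and then 10000, increase maxm, and tighten cutoff for the last few sweeps.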


# Single program input parameters and code features

16 changes: 15 additions & 1 deletion fulltest.cc
@@ -18,6 +18,7 @@ main(int argc, const char* argv[])
auto datadir = input.getString("datadir","/Users/mstoudenmire/software/tnml/CppMNIST");
auto fname = input.getString("fname","W");
auto imglen = input.getInt("imglen",28);
auto feature = input.getString("feature","series");

//auto labels = stdx::make_array<long>(2,5);
//auto labels = stdx::make_array<long>(7,8);
@@ -41,7 +42,20 @@ main(int argc, const char* argv[])
}

enum Feature { Normal, Series };
auto ftype = Normal;
auto ftype = Series;
if(feature == "norm" || feature == "normal")
    {
    ftype = Normal;
    }
else if(feature == "series")
    {
    ftype = Series;
    }
else
    {
    Error(format("feature type \"%s\" not recognized",feature));
    }

auto phi = [ftype](Real g, int n) -> Cplx
{
if(g < 0 || g > 255.) Error(format("Expected g=%f to be in [0,255]",g));
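For reference, the "normal" feature choice that the `phi` lambda above can select corresponds to the local feature map of arXiv:1605.05775. Here is a minimal standalone sketch; the function name `phiNormal` and the rescaling of g from [0,255] down to [0,1] are assumptions (the actual `phi` in fulltest.cc returns a Cplx and also implements the "series" feature type, which this sketch does not cover):

```cpp
#include <array>
#include <cmath>

//Sketch of the two-component "normal" local feature map from
//arXiv:1605.05775: phi(x) = ( cos(pi*x/2), sin(pi*x/2) ),
//where a grayscale pixel g in [0,255] is rescaled to x in [0,1]
std::array<double,2>
phiNormal(double g)
    {
    const double pi = 3.14159265358979323846;
    double x = g/255.;
    return {{ std::cos(pi*x/2.), std::sin(pi*x/2.) }};
    }
```

A white pixel (g = 0) maps to (1,0) and a black pixel (g = 255) maps to (0,1), so each component of the map stays normalized for any pixel value.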