
Threading performance needs to be evaluated #1

Open
kwalcock opened this issue Aug 30, 2021 · 19 comments

Comments

@kwalcock
Owner

kwalcock commented Aug 30, 2021

TorchScript seems to be thread-safe and thread-efficient. It gets the same answer each time and gets it faster with more threads. This was run on a 16-core, 32-processor machine. Each of the first few doublings of the thread count improves throughput by roughly 1.5x. Overall it maxes out at about 7x, when one would hope for 16-32x. Perhaps it would go higher on Clara.

[Three graphs of TorchScript threading performance omitted.]

Typos have been fixed in the graph captions.
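
For context, below is a minimal sketch of the kind of harness that produces numbers like these; the document type and the `predict` function are placeholders standing in for one TorchScript forward pass, not the actual benchmark code.

```scala
import java.util.concurrent.{Callable, Executors, TimeUnit}

// Minimal sketch (not the benchmark code from this repo): run docs.size forward
// passes spread across numThreads worker threads and report the wall-clock time.
def timeWithThreads(numThreads: Int, docs: Seq[String], predict: String => Unit): Double = {
  val pool = Executors.newFixedThreadPool(numThreads)
  val tasks = new java.util.ArrayList[Callable[Unit]]()
  docs.foreach { doc =>
    tasks.add(new Callable[Unit] { override def call(): Unit = predict(doc) })
  }
  val start = System.nanoTime()
  pool.invokeAll(tasks)                 // blocks until every document has been processed
  pool.shutdown()
  pool.awaitTermination(1, TimeUnit.HOURS)
  (System.nanoTime() - start) / 1e9     // elapsed seconds
}
```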

@MihaiSurdeanu

Thanks!

So, it seems to me it would be worth it to change our Metal DyNet library to TorchScript, no? What is your opinion?
Also, what are the RAM requirements for TorchScript for this task?

@kwalcock
Owner Author

kwalcock commented Aug 31, 2021

That RAM question is going to be difficult not just because it's Java but also because any underlying C-like code may have its own stash that Java might not know about. I can at least google and experiment.
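
For the Java side at least, one crude approach is to sample used heap before and after exercising the model; this is only a sketch and says nothing about memory that libtorch allocates natively.

```scala
// Rough heap sampling: force a GC and report used heap in MB.  Native allocations
// made by libtorch or onnxruntime will not show up here.
def usedHeapMb(): Long = {
  val rt = Runtime.getRuntime
  System.gc()
  (rt.totalMemory() - rt.freeMemory()) / (1024 * 1024)
}

val before = usedHeapMb()
// ... load the TorchScript module and run some forward passes here ...
val after = usedHeapMb()
println(s"Java heap delta: ${after - before} MB")
```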

<opinion>

This is still very much an apples vs. oranges comparison, I think. With TorchScript there can be (at least) a 7x speedup with threading, but maybe in the end on a complicated model it is 10x slower than DyNet and can't catch up. In one test, I believe at the FatDynet level and maybe on Clara, DyNet could be sped up 20x (clulab/processors#422 (comment)) and might eventually work from Scala. It looks like TorchScript can easily run on the GPU (https://towardsdatascience.com/pytorch-jit-and-torchscript-c2a77bac0fff), and that could help it catch up again. In general we're threading on a per-document basis, but often we run only one document at a time anyway, so we may lose speed on most small jobs if it is only the multi-threading that justifies TorchScript. It might be interesting to try threading per sentence on either one.

It would be nice to have alternatives if they don't cost too much. Processors could work against some InferenceEngine interface with different implementations (a rough sketch follows below), or could do something like, or with, the ONNX (https://github.com/onnx/onnx) project. It almost sounds reasonable as a learning exercise for someone to implement that same model on numerous platforms. However, it might take away from bug hunting, which could pay off very quickly (or never). There are lots of other places where performance might be improved, but they might be less interesting. I've wondered what would happen if all these Seqs were turned into Lists, for example. Some measurements show Lists to be a lot faster (https://www.lihaoyi.com/post/BenchmarkingScalaCollections.html) for many operations.

</opinion>
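
Here is a rough sketch of the InferenceEngine idea from above. The method signature is invented for illustration; the real one would follow whatever processors actually needs (token indexes in, label scores out, for instance).

```scala
// Sketch only: a common interface so processors doesn't care which backend runs the model.
trait InferenceEngine {
  def forward(tokenIndexes: Array[Long]): Array[Float]
}

class TorchScriptEngine(modelPath: String) extends InferenceEngine {
  // would wrap org.pytorch.Module.load(modelPath) and Tensor.fromBlob(...)
  def forward(tokenIndexes: Array[Long]): Array[Float] = ???
}

class OnnxEngine(modelPath: String) extends InferenceEngine {
  // would wrap ai.onnxruntime.OrtEnvironment / OrtSession
  def forward(tokenIndexes: Array[Long]): Array[Float] = ???
}
```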

@MihaiSurdeanu

I agree it's an apples vs. oranges comparison. But if we attempt to normalize them by accounting for the fact that the DyNet code is more complex, they become sort of similar. That is, our DyNet code on Clara is about 50% slower than the TorchScript one.
In general, TorchScript is appealing for three reasons:

  • Seems to like parallelism more than DyNet. This reason might disappear if you find that bug.
  • PyTorch is much better supported than DyNet.
  • PyTorch has pre-trained transformer networks. DyNet does not.

@MihaiSurdeanu

Let's discuss Thursday.

@kwalcock
Owner Author

1 thread requires approximately 145MB of Java memory. 8 threads require approximately 149MB, so each additional thread requires about 0.5MB. The model itself on disk is 935.5MB. Very little of the memory appears to be managed by Java.

1 thread

| Memory (MB) | Time (sec) |
|---|---|
| 8192 | 193.9408116 |
| 4096 | 191.989438 |
| 2048 | 193.5145225 |
| 1024 | 190.8282388 |
| 512 | 191.0930682 |
| 256 | 191.7231144 |
| 192 | 192.7174266 |
| 160 | 195.3915921 |
| 152 | 199.9521839 |
| 148 | 205.5872125 |
| 146 | 231.19455 |
| 145 | 228.9728285 |
| 144 | Java heap space |
| 128 | GC overhead limit exceeded |

8 threads

| Memory (MB) | Time (sec) |
|---|---|
| 256 | 62.28951887 |
| 192 | 69.24889569 |
| 160 | 66.72487354 |
| 152 | 70.56916087 |
| 150 | 74.55402432 |
| 149 | 76.57680217 |
| 148 | GC overhead limit exceeded |
| 144 | Java heap space |
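
Limits like the ones in these tables can be pinned down by forking the benchmark with an explicit heap setting and lowering it until the run fails with "Java heap space" or "GC overhead limit exceeded". A build.sbt sketch (the exact options used here may have differed):

```scala
// build.sbt sketch: fork the benchmark JVM and fix its heap, e.g. at 149MB.
fork := true
javaOptions ++= Seq("-Xmx149m", "-Xms149m")
```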

@kwalcock
Owner Author

C memory is next.

@MihaiSurdeanu

MihaiSurdeanu commented Aug 31, 2021 via email

@kwalcock
Owner Author

The embeddings themselves are not used, but rather only the vocabulary (per the code comment: `# Load word embeddings, (we just care about which tokens are in our vocab)`). The input is the index of the word in the vocabulary and the index of the label from CoNLL. In the Scala code this happens at `Tensor.fromBlob(tokenIndexes, Array(1L, tokenIndexes.length.toLong))` and in Python at `out_sent = torch.tensor(out_sent)`. That's just how the code was when I got it.

The embeddings do show up in a roundabout way at `self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)`, and I haven't tracked down what that's about. There is no equivalent in the TorchScript code except maybe in the loading of the stored model.
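
For completeness, a sketch of that input path using the pytorch_java_only API quoted above; the model path, example indexes, and output handling are placeholders.

```scala
import org.pytorch.{IValue, Module, Tensor}

// Sketch only: the "features" are just vocabulary indexes, shaped [1, sentenceLength],
// matching the Tensor.fromBlob call quoted from the Scala code.
val module = Module.load("model.pt")                        // placeholder path
val tokenIndexes: Array[Long] = Array(12L, 5L, 873L, 2L)    // word indexes from the vocab
val input = Tensor.fromBlob(tokenIndexes, Array(1L, tokenIndexes.length.toLong))
val output = module.forward(IValue.from(input)).toTensor()
val scores = output.getDataAsFloatArray()                   // per-label scores, flattened
```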

@MihaiSurdeanu

I think he is using random embeddings instead, which is fine for this experiment.
Thank you!

@kwalcock
Owner Author

kwalcock commented Sep 1, 2021

I have not found any good (and easy and accurate) way to measure how much memory is being used by the libraries. jcmd doesn't seem to help, /proc/pid doesn't divulge secrets, PyTorch doesn't provide any insight through the Java interface, etc. As a hack, I ran the test program with the minimal workable Java memory setting from above, so that Java didn't have any to spare, and then just used top to check how much memory was reserved for the process. For both 1 and 8 threads, 4.9GB was recorded. A simple hello world program generated from the same project measured just 64MB, which provides a sanity check. If I run sbt test on processors, the same column often shows between 6.6 and 7.1GB.
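
One scriptable alternative to watching top by hand, sketched below: on Linux the resident set size of the current process (Java heap plus whatever the native libraries have allocated) is reported as VmRSS in /proc/self/status. It still doesn't break the total down by library, though.

```scala
import scala.io.Source

// Sketch only: read the process's resident set size (reported in kB) from /proc.
def residentSetKb(): Option[Long] = {
  val src = Source.fromFile("/proc/self/status")
  try {
    src.getLines()
      .find(_.startsWith("VmRSS:"))
      .map(_.replaceAll("[^0-9]", "").toLong)
  } finally src.close()
}
```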

@kwalcock
Owner Author

kwalcock commented Sep 1, 2021

There seems to be a difficulty using TorchScript on a Mac. It looks like "System Integrity Protection" may need to be disabled from the GUI. There are some links about it towards the bottom of https://github.com/kwalcock/torchscript. I haven't been able to check yet whether that solves the problem.

@MihaiSurdeanu

MihaiSurdeanu commented Sep 1, 2021 via email

@kwalcock
Owner Author

These are mean times (sec) for a forward pass of one sentence on a single thread. Onnx is close to 2.5x faster here.

| Library | Train set | Val set | Test set |
|---|---|---|---|
| TorchScript | 0.00053581 | 0.00056500 | 0.00049214 |
| Onnx | 0.00022426 | 0.00022952 | 0.00020155 |
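
A sketch of how a single forward pass can be timed with the onnxruntime 1.8.1 Java API; the model path, input name, and example indexes are assumptions for illustration.

```scala
import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession}

val env = OrtEnvironment.getEnvironment()
val session = env.createSession("model.onnx", new OrtSession.SessionOptions())

// One sentence of vocabulary indexes, shaped [1, sentenceLength].
val tokenIndexes = Array(Array(12L, 5L, 873L, 2L))
val input = OnnxTensor.createTensor(env, tokenIndexes)

val start = System.nanoTime()
// "tokens" is a placeholder input name; the real name comes from the exported model.
val result = session.run(java.util.Collections.singletonMap("tokens", input))
val seconds = (System.nanoTime() - start) / 1e9
result.close()
```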

@kwalcock
Owner Author

kwalcock commented Sep 24, 2021

Onnx seems to bottom out at 16 threads, which happens to be the number of cores in the computer. Even though it plateaus sooner, its performance is still better than PyTorch's: at that point PyTorch was still taking 111 seconds to run while Onnx was down to 86.

[Three graphs of Onnx threading performance omitted.]
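
One thing to keep in mind when reading these curves: onnxruntime has its own internal thread settings that interact with application-level threading. A sketch of pinning intra-op parallelism to a single thread so that any scaling comes only from the caller's thread pool (assuming the SessionOptions knobs exposed by the Java API):

```scala
import ai.onnxruntime.{OrtEnvironment, OrtSession}

// Sketch only: keep each forward pass single-threaded inside onnxruntime.
val env = OrtEnvironment.getEnvironment()
val opts = new OrtSession.SessionOptions()
opts.setIntraOpNumThreads(1)
val session = env.createSession("model.onnx", opts)
```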

@MihaiSurdeanu

MihaiSurdeanu commented Sep 24, 2021 via email

@kwalcock
Owner Author

There is a library dependency, but I haven't looked inside the jar:

    "com.microsoft.onnxruntime"  % "onnxruntime" % "1.8.1",

PyTorch requires jars plus several C libraries on LD_LIBRARY_PATH:

    // This one requires the next.
    // "org.pytorch"           % "pytorch_java_only" % "1.9.0",
    // The next one can't be found.  Use jars in lib directory.
    // "com.facebook.fbjni"    % "fbjni-java-only"   % "0.0.3",
    // And this is a transitive dependency
    // "com.facebook.soloader" % "nativeloader"      % "0.8.0",

@MihaiSurdeanu

It probably means they have an interpreter, which is possibly natively supported.

What NN did you run exactly? Just a feed forward one, or the double LSTM?

@kwalcock
Owner Author

This is the same model from Peter as before:

    self.bilstm = nn.LSTM(
        input_size=embed_dim,
        hidden_size=hidden_dim,
        bidirectional=True,
        batch_first=True,
    )

@MihaiSurdeanu

Thanks!
