
Conversation

isidentical (Contributor) commented Sep 11, 2023

What does this PR do?

Fixes #4975 (LoRa loading is extremely inefficient due to repeated datatype queries). Adds a generic LRU cache for the expensive parameter-datatype computation function, with an optional fallback to the uncached implementation when the underlying module is not hashable (since the module can be any torch module subclass, it may have overridden __hash__ or added properties that make it unsafe to hash). From my local testing, the performance increase is very noticeable (2x to 3x); more benchmarks are below.
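A minimal sketch of the caching pattern described above, assuming a hypothetical get_parameter_dtype helper; the actual function this PR wraps and its exact fallback logic may differ:

```python
import functools

import torch


def _parameter_dtype_uncached(module: torch.nn.Module) -> torch.dtype:
    # Hypothetical stand-in for the expensive datatype query: return the dtype
    # of the first parameter found, falling back to the default dtype.
    for param in module.parameters():
        return param.dtype
    return torch.get_default_dtype()


@functools.lru_cache(maxsize=None)
def _parameter_dtype_cached(module: torch.nn.Module) -> torch.dtype:
    return _parameter_dtype_uncached(module)


def get_parameter_dtype(module: torch.nn.Module) -> torch.dtype:
    # The module can be any torch.nn.Module subclass, so it may have overridden
    # __hash__ (or otherwise be unsafe to use as a cache key); in that case
    # lru_cache raises TypeError and we fall back to the uncached path.
    try:
        return _parameter_dtype_cached(module)
    except TypeError:
        return _parameter_dtype_uncached(module)
```

Note that keying the cache on the module itself holds a strong reference to it and assumes its dtype does not change between queries, which is fine for repeated lookups during a single LoRA load/unload cycle.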

Benchmarks

For a script that loads the LoRA state dict into memory and then runs a load_lora_weights/unload_lora_weights cycle 5 times (a sketch of such a script follows the table), the results are as follows:

|                 | min   | max   | mean  | median | mean speed-up |
|-----------------|-------|-------|-------|--------|---------------|
| baseline (main) | 5.44s | 6.46s | 6.02s | 6.31s  | 1.0x          |
| This PR         | 1.84s | 2.05s | 1.96s | 1.94s  | 3.2x          |
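
For reference, here is a rough reconstruction of the kind of benchmark loop described above; the checkpoint, the LoRA weight path, and the use of safetensors to pre-load the state dict are placeholders and assumptions, not the exact script behind the numbers in the table:

```python
import time

import torch
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file

# Placeholder checkpoint; the model used in the original benchmark is not specified.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Pre-load the LoRA state dict into memory so disk I/O is excluded from the timings.
lora_state_dict = load_file("path/to/lora.safetensors")  # placeholder path

timings = []
for _ in range(5):
    start = time.perf_counter()
    # load_lora_weights also accepts an in-memory state dict; pass a copy in
    # case the loader mutates it.
    pipe.load_lora_weights(dict(lora_state_dict))
    pipe.unload_lora_weights()
    timings.append(time.perf_counter() - start)

print(
    f"min={min(timings):.2f}s max={max(timings):.2f}s "
    f"mean={sum(timings) / len(timings):.2f}s median={sorted(timings)[len(timings) // 2]:.2f}s"
)
```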

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

patrickvonplaten (Contributor) commented:

I'm sadly not seeing a big speed-up anymore (after having merged #4994). Could you quickly double-check whether you still see the same speed-up with the current version of main?

isidentical (Contributor, Author) commented:

Yeah, my own benchmarks also show no noticeable speed-up with this PR applied (I think this was because of non-meta devices, and my sequence of optimizations also reduced the need for it). Closing this PR as it's currently very good as is!
