
Large dataset #5

Closed
macsermkiat opened this issue Oct 26, 2022 · 4 comments

Comments

@macsermkiat

Hi, this is a great tutorial! Thank you for sharing.

I have a question about implementing Dragonnet with a large dataset (in my case, 200k subjects). Calculating the loss requires constructing a large matrix (200k x 200k) in float32, which cannot fit into memory. Do you have any suggestions?

Thanks

@kochbj
Owner

kochbj commented Oct 26, 2022

Glad you enjoyed it. :) I haven't looked at these models particularly recently, so could you clarify for me which part is the bottleneck?

@macsermkiat
Author

macsermkiat commented Oct 27, 2022

There are many places that construct large tensors. For example, the Euclidean distance between PhiControl and PhiTreatment in Tutorial 3:

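```python
import tensorflow as tf

class AIPW_Metrics:
    def pdist2sq(self, x, y):
        # Squared Euclidean distance between every row of x and every row of y:
        # ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
        x2 = tf.reduce_sum(x ** 2, axis=-1, keepdims=True)  # (Nc, 1)
        y2 = tf.reduce_sum(y ** 2, axis=-1, keepdims=True)  # (Nt, 1)
        dist = x2 + tf.transpose(y2, (1, 0)) - 2. * x @ tf.transpose(y, (1, 0))  # (Nc, Nt)
        return dist
```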
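As a sanity check of what `pdist2sq` computes, here is a small NumPy sketch of the same identity. NumPy stands in for TensorFlow here, and `pdist2sq_np` and the array sizes are illustrative, not taken from the tutorial. It confirms that ||x||^2 + ||y||^2 - 2 x y^T matches brute-force squared distances and makes the shapes explicit: x2 is (Nc, 1), y2 is (Nt, 1), and the result is (Nc, Nt).

```python
import numpy as np


def pdist2sq_np(x, y):
    """Squared Euclidean distances between the rows of x (Nc, d) and y (Nt, d)."""
    x2 = np.sum(x ** 2, axis=-1, keepdims=True)  # (Nc, 1)
    y2 = np.sum(y ** 2, axis=-1, keepdims=True)  # (Nt, 1)
    # ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>, broadcast to (Nc, Nt)
    return x2 + y2.T - 2.0 * x @ y.T


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 200))  # e.g. 5 control subjects, 200-dim Phi
y = rng.normal(size=(3, 200))  # e.g. 3 treated subjects
dist = pdist2sq_np(x, y)
brute = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
print(dist.shape, np.allclose(dist, brute))  # (5, 3) True
```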

Since the Phi layer has 200 nodes, x has shape (Nc, 200) and y has shape (Nt, 200), so x2 is (Nc, 1) and y2 is (Nt, 1),
where Nc is the number of control subjects and Nt is the number of treated subjects.

Then dist has shape (Nc, Nt), roughly (100000, 100000) in my dataset.
In float32 (4 bytes per element) that is 4 * 10^5 * 10^5 = 4 * 10^10 bytes, about 40 GB, which does not fit into GPU memory.
So I have to move the computation to the CPU, which also affects the other functions.
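The memory figure is easy to check: a dense matrix costs rows * cols * bytes per element, where float32 uses 4 bytes and float64 uses 8. A quick sketch (the helper name `dense_matrix_bytes` is hypothetical):

```python
def dense_matrix_bytes(n_rows, n_cols, bytes_per_element):
    """Memory needed to store a dense n_rows x n_cols matrix."""
    return n_rows * n_cols * bytes_per_element


n_control = n_treated = 100_000
f32 = dense_matrix_bytes(n_control, n_treated, 4)  # float32: 4 bytes per element
f64 = dense_matrix_bytes(n_control, n_treated, 8)  # float64: 8 bytes per element
print(f"float32: {f32 / 1e9:.0f} GB ({f32 / 2**30:.1f} GiB)")  # float32: 40 GB (37.3 GiB)
print(f"float64: {f64 / 1e9:.0f} GB ({f64 / 2**30:.1f} GiB)")  # float64: 80 GB (74.5 GiB)
```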

Also, computing this distance matrix takes a long time, since the cost grows quadratically with the number of subjects (Nc * Nt pairs, each a 200-dimensional dot product).
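One way to keep memory bounded, assuming the validation metric only needs each subject's nearest neighbor in the other group, is to process the rows of x in blocks so that only a (chunk_size, Nt) slice of the distance matrix exists at any time. This is a NumPy sketch, not the tutorial's code; `nearest_neighbors_chunked` and `chunk_size` are hypothetical names. It lowers peak memory but not the O(Nc * Nt * d) running time.

```python
import numpy as np


def nearest_neighbors_chunked(x, y, chunk_size=1024):
    """For each row of x, return the index of its nearest row in y
    (squared Euclidean distance), without materializing the full
    (len(x), len(y)) distance matrix."""
    y2 = np.sum(y ** 2, axis=1)  # (Nt,)
    nearest = np.empty(x.shape[0], dtype=np.int64)
    for start in range(0, x.shape[0], chunk_size):
        xb = x[start:start + chunk_size]               # (B, d), B <= chunk_size
        x2 = np.sum(xb ** 2, axis=1, keepdims=True)    # (B, 1)
        block = x2 + y2[None, :] - 2.0 * xb @ y.T      # (B, Nt) distances for this block
        nearest[start:start + xb.shape[0]] = np.argmin(block, axis=1)
    return nearest


# Usage with small random stand-ins for PhiControl / PhiTreatment
rng = np.random.default_rng(0)
phi_control = rng.normal(size=(2_000, 200)).astype(np.float32)
phi_treated = rng.normal(size=(1_000, 200)).astype(np.float32)
nn_idx = nearest_neighbors_chunked(phi_control, phi_treated, chunk_size=256)
print(nn_idx.shape)  # (2000,)
```

With chunk_size=1024 and Nt = 100000, each float32 block is 1024 * 100000 * 4 bytes, about 0.41 GB. The same row slicing should carry over to TensorFlow tensors; for the treated-to-control direction, swap the arguments.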

@kochbj
Owner

kochbj commented Oct 29, 2022

Hmm, those calculations are only used to find nearest neighbors in representation space for validation. A few quick solutions:

  1. Skip them, since they are only used in validation. Or compute them periodically on frozen models, perhaps in parallel on CPU?
  2. Do the training on the CPU. It is slower, but unless you are trying to do this in a production setting (which I strongly advise against, since these algorithms are still pretty immature), it might be okay. You don't need thousands of training iterations, so unless you are running a lot of simulations it could be reasonable.
  3. Try a different validation technique altogether, like Alaa's influence functions or Credence?
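Option 1 amounts to running the expensive metric only occasionally instead of after every epoch. Below is a framework-agnostic sketch; all names are hypothetical and the training and validation steps are stand-ins for your own model code. In Keras, the same check could live in an `on_epoch_end` override of a `tf.keras.callbacks.Callback`.

```python
def fit_with_periodic_validation(num_epochs, train_one_epoch, validate, every_n_epochs=10):
    """Train for num_epochs, but run the (expensive) validation only every
    every_n_epochs epochs and once after the final epoch."""
    history = []
    for epoch in range(num_epochs):
        train_one_epoch(epoch)
        is_last = epoch == num_epochs - 1
        if (epoch + 1) % every_n_epochs == 0 or is_last:
            history.append((epoch, validate(epoch)))
    return history


# Usage with stand-in callables (hypothetical; plug in your own model code)
def dummy_train(epoch):
    pass


def dummy_validate(epoch):
    return 0.0  # e.g. a nearest-neighbor metric on frozen representations


history = fit_with_periodic_validation(25, dummy_train, dummy_validate, every_n_epochs=10)
print([epoch for epoch, _ in history])  # [9, 19, 24]
```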

Hope this helps a bit!
-B

https://proceedings.mlr.press/v162/parikh22a.html
https://www.vanderschaar-lab.com/papers/Validating_CI_models_via_IF_manuscript.pdf

@macsermkiat
Author

Thank you. This helps a lot!
