
Pytorch refactor #168

Open: wants to merge 28 commits into master
Conversation

@ebolyen ebolyen commented Apr 19, 2022

No description provided.

@ebolyen ebolyen (Member Author) left a comment

@mortonjt, some updates:

We've pored over the model a few times, but we cannot get it to converge in the same place that the multimodal unit test expects. We think the starting conditions or the Adam optimizer may have an outsized effect here, but we're also not sure we didn't miss something.

The general structure of our tensors is [batch, sample, whatever], and tracing each operation seems to do what we expect.

So our X is drawn from a multinomial to do a bunch of categorical draws at once, giving us OTU indices in the form: [batch, sample]

That goes into the embedding giving us: [batch, sample, latent]
Then we slice the bias and add it to latent after reshaping it to match (maybe something went wrong here, but we've inspected that line a few different ways and it seems to do what we want).

Then we use the decoder: we run a linear model on one fewer dimension, giving us [batch, sample, ALR_metabolites], then we add zeros to the front of that last dimension and run softmax over it to hopefully get [batch, sample, P_metabolites].

We then parameterize the multinomial and calculate likelihoods.

As far as we can tell, this is what should be happening, so we don't really know why we get such unstable correlations from the unit tests, ranging from -0.29 to accidentally passing for U, and always failing by the time we check V.
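
Roughly, in sketch form (a simplified stand-in rather than the exact code in this PR; num_microbes, num_metabolites, and latent_dim are placeholder names):

import torch
import torch.nn as nn
from torch.distributions import Multinomial

class MMvecSketch(nn.Module):
    def __init__(self, num_microbes, num_metabolites, latent_dim):
        super().__init__()
        self.encoder = nn.Embedding(num_microbes, latent_dim)
        self.encoder_bias = nn.Parameter(torch.randn(num_microbes))
        # ALR: the linear layer predicts one fewer dimension; a zero reference
        # column is prepended below before the softmax
        self.decoder = nn.Linear(latent_dim, num_metabolites - 1)

    def forward(self, X, metabolite_counts):
        # X: [batch, sample] OTU indices drawn from the microbe table
        z = self.encoder(X)                                  # [batch, sample, latent]
        z = z + self.encoder_bias[X].reshape((*X.shape, 1))  # add per-microbe bias
        alr = self.decoder(z)                                # [batch, sample, ALR_metabolites]
        zeros = torch.zeros_like(alr[..., :1])               # zero reference column
        probs = torch.softmax(torch.cat([zeros, alr], dim=-1), dim=-1)  # [batch, sample, P_metabolites]
        dist = Multinomial(probs=probs, validate_args=False)
        return dist.log_prob(metabolite_counts)              # log-likelihood per [batch, sample]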

def mmvec_training_loop(model, learning_rate, batch_size, epochs):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 betas=(0.8, 0.9), maximize=True)
    for epoch in range(epochs):
@ebolyen ebolyen (Member Author) commented:

We need to add better logic here, so that a single epoch represents the correct number of batch draws for the data.

@ebolyen ebolyen (Member Author) commented:

@mortonjt, the paper seems to imply that an epoch represents a random draw (in batches, of course) for each read in the feature table, but the original code seems to use nnz, which I interpret to mean "n-non-zero". So this would be the number of distinct sample:microbe pairs, rather than the number of observations. What was the goal there, and should we replicate it?

@mortonjt mortonjt (Collaborator) commented May 13, 2022:

This line was largely to make the concept of epoch more interpretable. And yes, nnz is the number of non-zeros.

One epoch is completed if you read through the entire dataset, which means that you should be able to process all of the reads. Since the batch size is computed over the number of reads, this is used to compute the number of "iterations" within each loop.

So it should read like this: iterations per epoch = (total number of reads [aka nnz]) / (num reads per batch).

We're basically calculating how many batches are within an epoch, in order to read through the entire dataset.

That being said -- I don't think you really need this. I think the current implementation is fine -- we just need a way to make the term epochs interpretable to the user.

@ebolyen ebolyen (Member Author) commented:

I think nnz in the older implementation was actually the number of non-zero cells, not the sum of those cells.

It sounds like the goal was to make it the number of reads outright though (sum of the entire table). I think it's probably worth making sure that epoch fits that, if only for the sake of explanation. (It hasn't seemed to matter too much in practice while we've been testing.)
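
Concretely, the two readings of an epoch would look something like this (a hypothetical helper, not code from this PR):

import numpy as np

def iterations_per_epoch(table, batch_size, use_total_reads=True):
    # table: samples x microbes count matrix (dense numpy array here)
    total_reads = int(table.sum())          # sum of the entire table, i.e. the number of reads
    nnz = int(np.count_nonzero(table))      # number of non-zero sample:microbe cells
    numerator = total_reads if use_total_reads else nnz
    return max(1, numerator // batch_size)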

v_r, v_p = spearmanr(pdist(model.V.T), pdist(self.V.T))

self.assertGreater(u_r, 0.5)
self.assertGreater(v_r, 0.5)
@ebolyen ebolyen (Member Author) commented:

We always fail by this point, but often fail the u_r test above as well. @mortonjt, we're kind of at a loss here.

mmvec/ALR.py Outdated

forward_dist = forward_dist.log_prob(self.metabolites)

l_y = forward_dist.sum(0).sum()
@ebolyen ebolyen (Member Author) commented:

This is missing the norm that is multiplied against the data likelihood. @mortonjt, we aren't 100% sure what its purpose is, but it kind of looks like a weird mean if you squint.

What is the interpretation of this line: https://github.com/biocore/mmvec/blob/master/mmvec/multimodal.py#L137?

@mortonjt mortonjt (Collaborator) commented May 13, 2022:

ok, so there are two ways you can deal with the data:

  1. You try to use the mini-batches to approximate the loss on the entire dataset
  2. You just compute the per-sample loss for each mini-batch

For all intents and purposes, I think it is ok to just compute the per-sample loss -- this appears to be an emerging standard in deep learning.

I think taking a mean is very ok. It'll basically be just l_y = forward_dist.sum(0).mean(). I'm able to get the tests to pass once I run this model locally.
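
In other words, something like the following, where log_probs is the [batch, sample] tensor returned by log_prob (total_reads and batch_size in option 1 are placeholders, just to show the contrast):

# Option 1: scale the mini-batch up toward the full-dataset likelihood
l_y_scaled = log_probs.sum() * (total_reads / batch_size)
# Option 2 (what I'm suggesting): just take the per-sample loss
l_y = log_probs.sum(0).mean()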

@mortonjt mortonjt (Collaborator) commented:

How about this: let me try to reproduce the findings. Sometimes it may require tweaking learning rates and batch sizes.

U and V do have an identifiability issue, so that is something to consider. The one metric that should always pass is U @ V.
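
i.e., something along these lines (model.U/model.V and self.U/self.V as placeholder names for the learned and ground-truth factors):

import numpy as np
from scipy.stats import spearmanr

# The individual factors can differ by a rotation, but the product should agree.
learned = np.asarray(model.U @ model.V)
expected = np.asarray(self.U @ self.V)
r, p = spearmanr(learned.ravel(), expected.ravel())
assert r > 0.5 and p < 0.05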

@mortonjt mortonjt marked this pull request as ready for review May 17, 2022 16:09
@mortonjt mortonjt self-requested a review May 17, 2022 16:09
@mortonjt mortonjt (Collaborator) left a comment

I think the implementation in this pull request is actually correct. We don't expect the U and V tests to always pass (this is why we run SVD after fitting the model). It's the U @ V test that needs to pass.

I'm able to get the tests passing on my side (r>0.5, p<0.05). The only thing that you may want to drop is the total_count argument in the multinomial.
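
For reference, the post-fit SVD step looks roughly like this (a sketch, with U, V, and latent_dim as placeholder names, not the exact mmvec code):

import numpy as np

# Re-derive identifiable factors from the product U @ V; the product is unique
# even when U and V individually are only determined up to a rotation.
ranks = U @ V                                   # microbes x metabolites
u, s, vt = np.linalg.svd(ranks, full_matrices=False)
U_ident = u[:, :latent_dim] * np.sqrt(s[:latent_dim])
V_ident = np.sqrt(s[:latent_dim])[:, None] * vt[:latent_dim]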

self.encoder = nn.Embedding(num_microbes, latent_dim)
self.decoder = nn.Sequential(
    nn.Linear(latent_dim, num_metabolites),
    nn.Softmax(dim=2)
@mortonjt mortonjt (Collaborator) commented:

self.input_bias = nn.Parameter(torch.randn(num_microbes))

@ebolyen ebolyen (Member Author) commented:

I think you might have looked at an older commit. We should have that in the current model.

# Three likelihoods, the likelihood of each weight and the likelihood
# of the data fitting in the way that we thought
# LY
z = self.encoder(X)
@mortonjt mortonjt (Collaborator) commented:

bias = self.input_bias[X]
z = z + bias.view(-1, 1)

@ebolyen ebolyen (Member Author) commented:

Same as above, although the .view(-1, 1) looks nicer
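
For our own sanity, a quick toy check of the broadcast in the batched version (made-up shapes):

import torch

batch, sample, latent, num_microbes = 5, 3, 2, 10
X = torch.randint(0, num_microbes, (batch, sample))   # microbe indices
encoder_bias = torch.randn(num_microbes)              # one bias per microbe
z = torch.randn(batch, sample, latent)                # embedded microbes

z = z + encoder_bias[X].reshape((*X.shape, 1))        # broadcasts over the latent dim
print(z.shape)                                        # torch.Size([5, 3, 2])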



mmvec/ALR.py Outdated
z = z + self.encoder_bias[X].reshape((*X.shape, 1))
y_pred = self.decoder(z)

forward_dist = Multinomial(total_count=0,
@mortonjt mortonjt (Collaborator) commented:

I'd suggest getting rid of the total_count=0 parameter -- we don't actually need it for log_prob.
And it may introduce a bug downstream (since the total_count isn't actually zero).

@ebolyen ebolyen (Member Author) commented:

This was actually a result of doing things in batches. We would run into an issue where log_prob would indicate our calculation was out of the support of the distribution, because the counts differed from sample to sample, so we solved it via this suggestion:

pytorch/pytorch#42407 (comment)

That said, looking at the documentation again, I wonder if we should be using logits instead of probs?
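
Something like this is what we had in mind (a sketch; whether validate_args needs to be disabled may depend on the torch version):

import torch
from torch.distributions import Multinomial

# log_prob does not actually use total_count, so the distribution can be built
# without it; validate_args=False sidesteps the total-count support check,
# since the totals differ from sample to sample here.
counts = torch.tensor([[3., 0., 2.],
                       [1., 4., 0.]])                 # different totals per sample
probs = torch.softmax(torch.randn(2, 3), dim=-1)

lp_from_probs = Multinomial(probs=probs, validate_args=False).log_prob(counts)
# Equivalent construction straight from unnormalized decoder output (logits),
# which would let us drop the explicit softmax:
lp_from_logits = Multinomial(logits=probs.log(), validate_args=False).log_prob(counts)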

@ebolyen ebolyen commented May 17, 2022

Thanks for the review @mortonjt!
