
Possible bug in latent vector loss calculation? #34

Closed
walmsley opened this issue Feb 15, 2021 · 14 comments

walmsley (Contributor) commented Feb 15, 2021

I'm confused by this and wondering if it could be a bug. It seems as though latents is of size (32, 128), which means that for array in latents: iterates 32 times. However, the results of these iterations aren't stored anywhere, so at best they are a waste of time and at worst they cause a miscalculation. Perhaps the intention was to accumulate the kurtoses and skews for each array in latents, and then compute lat_loss from all the accumulated values?

for array in latents:
    mean = torch.mean(array)
    diffs = array - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0

lat_loss = lat_loss + torch.abs(kurtoses) / num_latents + torch.abs(skews) / num_latents

Occurs at https://github.com/lucidrains/big-sleep/blob/main/big_sleep/big_sleep.py#L211
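
For reference, a minimal sketch of what that accumulation could look like (total_skew_kurt is a hypothetical name; mathematically this amounts to the same thing as moving the lat_loss line inside the loop):

    # sketch only: accumulate the per-latent statistics so that all 32 latents
    # contribute, instead of only the values from the final loop iteration
    total_skew_kurt = 0.
    for array in latents:
        mean = torch.mean(array)
        diffs = array - mean
        var = torch.mean(torch.pow(diffs, 2.0))
        std = torch.pow(var, 0.5)
        zscores = diffs / std
        skews = torch.mean(torch.pow(zscores, 3.0))
        kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0
        total_skew_kurt = total_skew_kurt + (torch.abs(kurtoses) + torch.abs(skews)) / num_latents

    lat_loss = lat_loss + total_skew_kurt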

walmsley (Contributor) commented Feb 15, 2021

Digging around further in the loss calculation code, I'm also curious about two other points:

  1. Is this supposed to be torch.abs(torch.mean(latents, dim=1)).mean()? Otherwise I believe the outer .mean() is redundant, because the inner mean has already reduced everything to a single value (see the sketch after this list).

    torch.abs(torch.mean(latents)).mean() + \

  2. The class loss term processes soft_one_hot_classes of size (32,1000) but if I'm reading it correctly, the zero index [0] extracts only the first of 32 arrays from the output of the topk() operation in topk(soft_one_hot_classes, ..., dim=1)[0].
    a) Is it intentional that the loss only cares about the first array of soft_one_hot_classes and ignores the other 31?
    b) If so, then the topk operation looks like it could be made faster by only operating on topk(soft_one_hot_classes[0], ...) instead of topk(soft_one_hot_classes, ..., dim=1)[0]

    cls_loss = ((50 * torch.topk(soft_one_hot_classes, largest = False, dim = 1, k = 999)[0]) ** 2).mean()
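
To illustrate point 1 concretely (a standalone sketch; a random tensor stands in for the real latents):

    import torch

    latents = torch.randn(32, 128)                    # stand-in for the real latents
    a = torch.abs(torch.mean(latents)).mean()         # inner mean reduces to a scalar, so the outer .mean() does nothing
    b = torch.abs(torch.mean(latents, dim=1)).mean()  # per-latent means, then the average of their absolute values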

lucidrains (Owner) commented:

@walmsley Hi Will! You caught a bug regarding your first comment! I've fixed it in 0.5.1 🙏

For your second comment, I double checked the original colab from Ryan, and that's what he has there, so you may want to redirect your question to him (he originally devised this technique)

As for #2, topk actually returns a tuple: the first element is the actual topk values, and the second the topk indices. The [0] is there to capture the values, not to reference the first latent.
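
For example (a standalone illustration with random stand-in values):

    import torch

    soft_one_hot_classes = torch.randn(32, 1000)
    values, indices = torch.topk(soft_one_hot_classes, k=999, largest=False, dim=1)
    # values.shape == indices.shape == (32, 999): all 32 rows are kept,
    # so the [0] in the loss selects the values tensor, not the first latent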

lucidrains (Owner) commented Feb 15, 2021

@walmsley I'm going to go with your suggestion for #1, because I think you are right :) You should send Ryan a poke and let him know!

walmsley (Contributor) commented Feb 15, 2021

That fully addresses the above points, thanks!

But, perhaps most significantly, there is a deeper possible bug that I'm curious about. I was trying to understand why num_latents is set to 32 here:

class Latents(torch.nn.Module):
    def __init__(
        self,
        num_latents = 32,

Diving deeper, it seems as though these 32 different vectors are only actually used within cond_vector here:

z = self.gen_z(cond_vector[0].unsqueeze(0))

and here:
z = layer(z, cond_vector[i+1].unsqueeze(0), truncation)

I debugged the loop surrounding line 520 above (using the current 512px BigGAN model) and found that the model only contains 15 layers; of those 15, only 14 are GenBlock layers, which trigger line 520.

The result is that, of the 32 latent vectors we create, only indices {0,1,2,3,4,5,6,7,8,10,11,12,13,14,15} are ever used. This wouldn't be a problem, except that the remaining 17 unused latent vectors may still be influencing the loss calculation. I'm still trying to work out whether their influence on the loss is significant enough to merit fixing this, because the fix would be slightly nontrivial: it varies depending on the size of the BigGAN model chosen.
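
As a small sanity check of that bookkeeping (the non-GenBlock position is inferred from the skipped index 9; the numbers are those found while debugging the 512px model):

    # illustration only, with hypothetical hard-coded numbers
    num_layers = 15
    non_genblock_positions = {8}                 # e.g. the self-attention layer
    used = {0}                                   # cond_vector[0] feeds gen_z
    used |= {i + 1 for i in range(num_layers) if i not in non_genblock_positions}
    print(sorted(used))   # [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15]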

lucidrains (Owner) commented:

@walmsley it's not a big deal, because there are only a handful of BigGAN models; we can just store a dictionary mapping GAN size -> num_latents somewhere.

walmsley (Contributor) commented Feb 15, 2021

Okay, so I currently believe that a solution could look like this:

  1. count the number of GenBlock layers by obtaining len(self.model.biggan.config.layers) (e.g. 14 layers), and add 1 to account for the 0th index being used for line 510, i.e. num_latents = len(self.model.biggan.config.layers) + 1, which is typically 15 (see the sketch after this list).
  2. modify the loop around line 520 to avoid incrementing the index when we arrive at a non-GenBlock layer:

    for i, layer in enumerate(self.layers):
        if isinstance(layer, GenBlock):
            z = layer(z, cond_vector[i+1].unsqueeze(0), truncation)
            # z = layer(z, cond_vector[].unsqueeze(0), truncation)
        else:
            z = layer(z)

    ... changed to something like:
        next_available_latent_index = 1
        for layer in self.layers:
            if isinstance(layer, GenBlock):
                z = layer(z, cond_vector[next_available_latent_index].unsqueeze(0), truncation)
                next_available_latent_index += 1
            else:
                z = layer(z)
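
Putting point 1 together with the instantiation of Latents (quoted later in this thread by @htoyryla), a rough sketch of the calling side could be:

    # sketch only: size the Latents module from the model config instead of hard-coding 32
    num_latents = len(self.model.biggan.config.layers) + 1   # +1 for cond_vector[0], which feeds gen_z

    self.latents = Latents(
        num_latents = num_latents,
        max_classes = self.max_classes,
        class_temperature = self.class_temperature
    )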

htoyryla commented Feb 15, 2021

I wonder how much it helps to calculate the skews and kurtoses for the relevant latents only, when the mean and std of all 32 latents are used anyway for the same loss here:

lat_loss = torch.abs(1 - torch.std(latents, dim=1)).mean() + \
    torch.abs(torch.mean(latents, dim=1)).mean() + \
    4 * torch.max(torch.square(latents).mean(), latent_thres)

I came to think about this while experimenting with image search on BigGAN, using the same code as a basis. Perhaps we should instead simply dimension the Latents object with the correct amount to begin with.

Also, we are still missing the proper way to accumulate the skews and kurtoses from each latent. Simply add up their absolute values (that's what I am doing right now)? The inner part of the loop is in any case the same as here: https://discuss.pytorch.org/t/statistics-for-whole-dataset/74511

Or, alternatively, skew and kurtosis might not be so important here either, given that the code has worked so well as it is. Anyway, I like to experiment. Changing the loss function a bit might not make it objectively better, but it can still give visually different results (which is what matters to me: visual diversity to be explored).

walmsley (Contributor) commented Feb 15, 2021

Perhaps we should instead simply dimension the Latents object with the correct amount to begin with.

Yes, that was the intention, thanks for clarifying – num_latents = len(self.model.biggan.config.layers) + 1 should be used to fix num_latents from 32->15 at the source, not just for the sake of kurtoses/skews.

Also we are still missing the proper way to accumulate the skews and kurtoses from each latent

Note that this was fixed in release 0.5.1, as mentioned by @lucidrains.

htoyryla commented Feb 15, 2021

Yes, that was the intention, thanks for clarifying – num_latents = len(self.model.biggan.config.layers) + 1 should be used to fix num_latents from 32->15 at the source, not just for the sake of kurtoses/skews.

To be precise, that must be done when the Latents object is instantiated, either via this default

num_latents = 32,

or here:

self.latents = Latents(
    max_classes = self.max_classes,
    class_temperature = self.class_temperature
)

Or just use latents[:num_latents] in calculating the mean and the std.
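
For example, the slicing variant could look like this (a sketch; used is just a hypothetical name, and latent_thres is the existing threshold from the quoted loss):

    used = latents[:num_latents]   # only the latent rows that actually feed the generator

    lat_loss = torch.abs(1 - torch.std(used, dim=1)).mean() + \
        torch.abs(torch.mean(used, dim=1)).mean() + \
        4 * torch.max(torch.square(used).mean(), latent_thres)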

htoyryla commented Feb 15, 2021

Also we are still missing the proper way to accumulate the skews and kurtoses from each latent

Note that this was fixed with release 0.5.1 as mentioned by @lucidrains

I don't think anything was done about this. As far as I know, @lucidrains stated that the code comes directly from Ryan (so I sent him a link to this discussion).

What I currently have is below (but I can't see any major effect... it might be that the whole skewness etc. factor is not so critical here?). My application is not big_sleep but image search, which explains the small difference (latents vs. lats.normu).

skews = 0
kurtoses = 0
for array in lats.normu:
    mean = torch.mean(array)
    diffs = array - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews += torch.abs(torch.mean(torch.pow(zscores, 3.0)))
    kurtoses += torch.abs(torch.mean(torch.pow(zscores, 4.0)) - 3.0)
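
If those accumulated values are then folded back into the latent loss in the same way as the original line, it might look roughly like this (num_latents standing for the number of latent vectors, as in the original code; the absolute values are already taken inside the loop above):

    lat_loss = lat_loss + (skews + kurtoses) / num_latents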

walmsley (Contributor) commented:

@htoyryla the skew/kurtosis fix was subtle; it's just the change in indentation of line 211 here: 226b973#diff-a32d425a1d65b549cda9588699a004a9d283f46d0623256309606cc74f8d3dd8R211

htoyryla commented Feb 15, 2021

@htoyryla the skew/kurtosis fix was subtle; it's just the change in indentation of line 211 here: 226b973#diff-a32d425a1d65b549cda9588699a004a9d283f46d0623256309606cc74f8d3dd8R211

I see. I was only looking at the newest commit, the one with your name mentioned, and did not realise this commit was also so recent. I therefore thought that the error Phil had fixed was the missing dim=1 :)

Looks like the code with the indentation fix is more or less equivalent to my solution.

htoyryla commented Feb 15, 2021

BTW, my interest in this loss calculation comes from my ongoing experiment to skip the one-hot coded class label in BigGAN here:

embed = self.embeddings(class_label)
cond_vector = torch.cat((z, embed), dim=1)
and instead use a 128-element class embedding. I modified the loss to use similar code for the class embedding as for the latents, and then I noticed that the loss was not working correctly at all, but pushing the values in the wrong direction.
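
A rough sketch of that idea, with hypothetical names and shapes (z assumed to be the (1, 128) noise half of the conditioning vector):

    # sketch only: optimise a 128-element class embedding directly,
    # bypassing the one-hot class label and self.embeddings()
    class_embed = torch.nn.Parameter(torch.randn(1, 128) * 0.01)

    # inside the forward pass, in place of embed = self.embeddings(class_label):
    cond_vector = torch.cat((z, class_embed), dim=1)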

walmsley (Contributor) commented Feb 16, 2021

Created new PR with proposed final fix @ #35

Overall status of 4 possible bugs mentioned in this issue:

  • kurtosis/skew accumulation (fixed in release 0.5.1)
  • latent loss mean(dim=1) (fixed in release 0.5.2)
  • class loss topk[0] (not a bug, no fix needed)
  • num_latents 32 -> 15 (fixed in release 0.5.3)
