
Issue about generated images #6

Open
wzmsltw opened this issue Feb 25, 2022 · 45 comments

Comments

@wzmsltw

wzmsltw commented Feb 25, 2022

Hi

I have also been trying to reproduce MaskGIT recently. After training for 150 epochs on ImageNet, our model only achieves 8.4% accuracy on token classification. During sampling, we find that the model generates monochrome images (nearly white). Have you run into a similar problem?

@dome272

dome272 commented Feb 25, 2022

Unfortunately, I have not been able to train a model yet, because my compute resources are allocated elsewhere at the moment. Did you use your own codebase? If so, did you publish it somewhere?

@pabloppp

@wzmsltw I'm also building a custom model inspired by this paper on the CelebA dataset, and I see something similar happening. I think in my case it's just still early in training: I get accuracies of around ~40%, which makes sense since during training the masking follows a cosine schedule (as the paper says), and under that schedule the average fraction of visible tokens is around 0.363, so 40% means the model does only slightly better than just copying the unmasked tokens.
This actually makes sense, since we don't want a perfect 100%, which would mean the model overfits; we want to be able to sample several possible reconstructions from a masked image, so that doesn't worry me too much...
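
For reference, here's a quick numerical check of that 0.363 figure (my own sketch, not code from any of our implementations); it just averages the cosine mask ratio over uniformly sampled training steps:

import math
import torch

# The paper's cosine schedule masks a fraction gamma(r) = cos(pi * r / 2) of the
# tokens, with r drawn uniformly from [0, 1] at each training step.
r = torch.rand(1_000_000)
mask_ratio = torch.cos(math.pi * r / 2)

print(f"average fraction masked:  {mask_ratio.mean().item():.3f}")        # ~0.637 (= 2/pi)
print(f"average fraction visible: {(1 - mask_ratio).mean().item():.3f}")  # ~0.363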

What I find strange is what happens during the sampling.
At the first step, the model generates something pretty promising, while being very noisy:
image

And as the sampling process continues, it would seem like the model is starting to generate something that looks like a face :D
image

BUT at some point, the sampling makes the face start to fade away...
image

And we end up with an almost empty image in most cases :/ It looks like, as I start to decrease the temperature during sampling, the samples collapse to an empty background, or idk...
image

Maybe something similar is happening to you... I don't know if this will be fixed with more training, or if it's a problem in the sampling procedure... wdyt?

@dome272

dome272 commented Mar 20, 2022

Did any of you find a solution? @pabloppp @wzmsltw

@pabloppp

Unfortunately no. I have the feeling that the issue is with the sampling method, but it might also be related to how the tokens are masked during training :/

@dome272

dome272 commented Mar 20, 2022

May I ask whether you used my implementation or your own? If you used mine, then it might also be something implementation-specific. However, if we both get the same white images with different implementations, then yes, there might be something wrong with the overall sampling strategy.

@pabloppp

I used my own implementation

@LeeDoYup

I still cannot reproduce the results in the paper after 300 epochs of training on ImageNet. However, I find that temperature annealing is key for the performance of MaskGIT and the diversity of the generated samples. Without temperature annealing I got an FID of 32.87; with temperature annealing that linearly decreases the temperature of the logits from 3.0 to 1.0, I got 20.26.

When I train MaskGIT on FFHQ, very simple (but high-quality) images are generated. However, due to the simplicity of the generated images, the recall of the trained model is very low and the FID is over 100.

I think many tricks are required to train MaskGIT, but the details are not described in the paper. In particular, temperature annealing is a very important trick for decreasing FID, yet the authors did not describe its details. How can I trust the scores in the original paper...
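
To be concrete, what I mean by linear temperature annealing is roughly the following (a simplified sketch with made-up names, not my actual training code):

import torch

def anneal_temperature(step: int, num_steps: int,
                       t_start: float = 3.0, t_end: float = 1.0) -> float:
    # Linearly decay the softmax temperature from t_start to t_end
    # over the decoding steps (hypothetical helper).
    ratio = step / max(num_steps - 1, 1)
    return t_start + (t_end - t_start) * ratio

# usage inside one decoding step, assuming logits of shape (batch, positions, vocab)
logits = torch.randn(2, 256, 1024)
temp = anneal_temperature(step=0, num_steps=8)
probs = torch.softmax(logits / temp, dim=-1)
sampled = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1)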

@dome272

dome272 commented Mar 20, 2022

Oh, that's really interesting about the temperature annealing. I completely overlooked it.
Without annealing, sampling looks like this:
image
And with annealing going from 3 to 2 it looks like this:
image

(Trained for just 28 epochs on a Flickr landscape dataset)

@pabloppp

The samples I shared used temperature annealing as well, but I still don't get very good results.

@dome272

dome272 commented Mar 21, 2022

@pabloppp @LeeDoYup If you are interested, maybe we can make a group on Discord and report new findings. I would also be interested in your Transformer implementations; I guess mine is quite simplistic. So if you are up for it, you can add me on Discord: dome#8231

Also, the authors reference BEiT, which uses a slightly different way of training. Even though the authors clearly describe their own training setup, maybe using the approach from BEiT could bring improvements. Have you tried anything like this?

@pabloppp

I have not tried anything like BEiT; in fact, my architecture is pretty different from the one proposed in the paper. What I tried to follow as closely as possible were the losses, training schedules & sampling schedules.

@LeeDoYup

I do not have a Discord account, so I will try to create one soon.
@pabloppp What was your temperature annealing setting? In my case, on FFHQ, start_temperature is 3.0 or 5.0 and end_temperature is 1.0.

@pabloppp

I have it as a parameter in my sampling function. I also tried different schedules for the annealing (linear, cosine, etc.) with very similar results, so I don't exactly remember the parameters used in the samples I shared above, but in these samples, for example, I used a cosine-decaying temperature (following the same percentages as the schedule for the number of sampled tokens) going from 1.5 to 0.3 over 16 steps.
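
In case it helps, this is roughly what I mean by the token-count schedule the temperature follows (a rough sketch with made-up names, not my exact code): the number of tokens that stay masked at step t is proportional to cos(pi/2 * t / T), so few tokens are revealed early and many in the last steps.

import math

def tokens_per_step(total_tokens: int, num_steps: int = 16):
    # number of tokens still masked after each step, following the cosine schedule
    masked = [math.floor(total_tokens * math.cos(math.pi / 2 * t / num_steps))
              for t in range(num_steps + 1)]  # masked[0] == total_tokens, masked[-1] == 0
    # number of tokens newly revealed at each step
    return [masked[t] - masked[t + 1] for t in range(num_steps)]

print(tokens_per_step(256))  # few tokens revealed early, many in the final steps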

Here I'm showing the first output, the half-schedule output, and the final output:
image
image
image

@pabloppp

It shouldn't be relevant, but my model is conditional, so I add an identity embedding to the input. My goal was to be able to control the generation to some degree, so I can ask the model to generate a specific face instead of a random one, and also to help the model, since conditional generation is usually much easier for generative models.

The model seems to be able to use that information to some extent, like generating male/female faces depending on the reference image, but it does a lot of random generation as well.

(Bottom image is a conditional sampling from scratch)
image

@pabloppp

Hello! Small update: I just tried adding typical filtering to the sampling code, and the results are still far from perfect, but I managed to go from a very high percentage of plain colored images to a considerable percentage of face-ish results :D

Here's an example without typical filtering:
image

And here are a couple of examples with typical filtering (with a mass of 0.2):
image
image

For the filtering I just adapted the code from the official repo:
https://github.com/cimeister/typical-sampling/blob/3e676cfd88fa2e6a24f2bdc6f9f07fddb87827c2/src/transformers/generation_logits_process.py#L242-L272

Although the 'typical filter' was designed around assumptions about how language works, allowing the model to pick from a large pool of options when the expected information is high, and shrinking the pool when the expected information is low, seems to benefit image generation too. I think it might even be related to the non-sequential nature of the sampling: at the beginning, when sampling the first tokens, the expected information is pretty high, so the model can pick from a wider variety of options, while as the image starts taking shape the options are reduced, since we already have a general sketch of the image... Or idk, it might be something completely different XD

Anyway, hopefully this is useful for someone, and maybe we could even reach out to @cimeister to ask if they have thought about this for image generation 🤔

@LeeDoYup

@pabloppp Oh, did you use typical sampling inside the multinomial sampling step that predicts the code at each position? I think it would help increase diversity, since typical sampling is known to resolve the degeneration problem in NLG.

@pabloppp

pabloppp commented Mar 28, 2022

Yes, basically before calling multinomial sampling I do what TypicalLogitsWarper does and set the logits of the filtered-out tokens to -inf, so the multinomial only samples from the filtered pool. I keep the temperature decay and the schedule for the number of tokens sampled at each step untouched.

@cimeister

Wow very cool! Thanks for sharing, @pabloppp. We hadn't tried typical sampling yet for image generation but it seems like a promising direction!

@LeeDoYup

The paper describes it as follows:

In practice, the masking tokens are randomly sampled with temperature annealing to encourage more diversity, and we will discuss its effect in 4.4.

Here, I am confused about whether they use temperature annealing (TA) to randomly select the masking positions, or use TA in the multinomial sampling at each position.
When I randomly select the masking positions, I get a large performance improvement, although the performance in the paper is still not reproduced.

@pabloppp

@LeeDoYup I'm pretty sure the temperature is applied to the logits before the softmax, thus affecting the multinomial sampling 🤔 but the things you mention are correlated: you change the temperature, so the probability of sampling certain tokens varies; then you apply the multinomial and keep only a number of tokens, based on their scores, following the cosine schedule.

@LeeDoYup

@pabloppp When I reason about the mask selection based on the algorithm, I also agree that TA is applied to the logits before the softmax. However, if I only read the sentence above, it sounds like they add randomness to the "mask selection", not to the token sampling. So I am very confused.

When I use the random masking strategy, the performance on ImageNet improves a lot. For example, with n_decoding_step=8, I got precision=0.63 and recall=0.36 with linear temperature annealing (5.0 => 1.0). However, I got precision=0.64 and recall=0.58 when I randomly select the tokens to unmask and fix the temperature at 0.8 over all decoding steps. I think the hyper-parameter tuning of MaskGIT is very exhausting...

@dome272

dome272 commented Mar 29, 2022

@LeeDoYup Can you show a pseudo-code example for both of the cases you describe above?
I'm a bit confused about which is which. What exactly do you mean by the random masking strategy?
What I have implemented is choosing the indices which have the highest probability. Are you just randomly selecting these tokens instead?

@LeeDoYup

@dome272 I use the random masking strategy as follows:

if strategy == 'random':
    candidates = masked_idxs_sorted
    subset = torch.randperm(candidates.shape[0])[:n_newly_unmasked]
    newly_unmasked_idxs = masked_idxs_sorted[subset]

masked_idxs_sorted => the indices that are still masked at each decoding step, sorted by their confidence
n_newly_unmasked => the number of tokens to additionally unmask at each decoding step

That is, I randomly select the positions of the tokens to unmask.
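
For comparison, the confidence-based selection (my reading of what the paper describes) would just keep the top of the sorted list, reusing the same variables as above:

if strategy == 'confidence':
    # masked_idxs_sorted is assumed to be sorted by confidence in descending order,
    # so this keeps the most confident predictions instead of a random subset
    newly_unmasked_idxs = masked_idxs_sorted[:n_newly_unmasked]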

@pabloppp

pabloppp commented Mar 29, 2022

In some way, it makes sense to choose them randomly, since during training you mask them randomly, so the model is used to reconstructing random missing tokens, not necessarily starting with the highest-scored ones and ending with the lowest. That said, in the paper they very explicitly say that they keep the highest-scored tokens.

image

At each iteration, the model predicts all tokens simultaneously but only keeps the most confident ones.

I'm pretty sure that when they say "In practice, the masking tokens are randomly sampled with temperature annealing to encourage more diversity" they mean that they don't just pick the token with the highest score, but sample from a multinomial distribution adjusted by the temperature. This multinomial sampling introduces a randomness factor (the higher the temperature, the more randomness).

Anyway, if you've found that sampling randomly instead of taking the highest scores helps, it's worth a try 🙇

BTW what are you doing exactly to get recall & precision from a random sampling? 🤔 What do you compare your output to in order to get metrics?

@LeeDoYup

@pabloppp I totally agree with you. However, the most problematic part is that the hyper-parameters of TA are not described in the paper, so it is hard to reproduce the results.

To evaluate recall & precision on ImageNet, I generated 50K samples and used the protocol in this repository.

@dome272

dome272 commented Mar 29, 2022

@pabloppp
Could you show a code sample of where you apply the typical sampling?

  logits /= temperature
  filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
  probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
  sample = torch.multinomial(probs, 1)

That's the normal top-k/top-p sampling. At which point do you call TypicalLogitsWarper, and with which input?

@pabloppp

pabloppp commented Mar 29, 2022

I don't think you're doing it right. You're supposed to first sample a token for every masked position and then pick the top-k with the highest scores; otherwise, you don't really know the score of the token you sampled.

I don't use the TypicalLogitsWarper class directly, but I follow the same implementation. This is the core of my sampling code:

logits, _ = self(x, c, mask)
probs = logits.div(temp)  # temperature-scaled logits (kept unnormalized until the final softmax)
probs_flat = probs.permute(0, 2, 3, 1).reshape(-1, probs.size(1))  # flatten to (batch * positions, vocab)
if typical_filtering:
    # entropy of the predicted distribution at each position
    probs_flat_norm = torch.nn.functional.log_softmax(probs_flat, dim=-1)
    probs_flat_norm_p = torch.exp(probs_flat_norm)
    entropy = -(probs_flat_norm * probs_flat_norm_p).nansum(-1, keepdim=True)

    # sort tokens by how far their information content is from the entropy ("typicality")
    probs_flat_shifted = torch.abs((-probs_flat_norm) - entropy)
    probs_flat_sorted, probs_flat_indices = torch.sort(probs_flat_shifted, descending=False)
    probs_flat_cumsum = probs_flat.gather(-1, probs_flat_indices).softmax(dim=-1).cumsum(dim=-1)

    # keep the smallest set of most typical tokens whose cumulative probability reaches typical_mass
    last_ind = (probs_flat_cumsum < typical_mass).sum(dim=-1)
    sorted_indices_to_remove = probs_flat_sorted > probs_flat_sorted.gather(1, last_ind.view(-1, 1))
    if typical_min_tokens > 1:
        sorted_indices_to_remove[..., :typical_min_tokens] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(1, probs_flat_indices, sorted_indices_to_remove)
    probs_flat = probs_flat.masked_fill(indices_to_remove, -float("Inf"))  # filtered-out tokens get zero probability
probs_flat = probs_flat.softmax(dim=-1)
sample_indices = torch.multinomial(probs_flat, num_samples=1)  # sample one token per position
sample_scores = torch.gather(probs_flat, 1, sample_indices)  # confidence of the sampled token
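
For context, the outer loop this plugs into looks roughly like this (a simplified sketch with a hypothetical sample_step helper wrapping the code above; not my exact implementation, and it assumes every batch item has the same number of masked tokens):

import math
import torch

def iterative_decode(sample_step, x, mask, num_steps=16, start_temp=1.5, end_temp=0.3):
    # x: (batch, seq) token ids; mask: (batch, seq) bool, True = still masked.
    # sample_step(x, mask, temp) is assumed to return (sample_indices, sample_scores),
    # both of shape (batch, seq).
    b, seq = x.shape
    for step in range(num_steps):
        temp = start_temp + (end_temp - start_temp) * step / max(num_steps - 1, 1)
        sample_indices, sample_scores = sample_step(x, mask, temp)
        # only still-masked positions compete for being revealed this step
        scores = sample_scores.masked_fill(~mask, -float("inf"))
        # cosine schedule on the token counts: reveal few tokens early, many late
        n_masked_next = math.floor(seq * math.cos(math.pi / 2 * (step + 1) / num_steps))
        n_reveal = int(mask[0].sum()) - n_masked_next
        if n_reveal > 0:
            keep = scores.topk(n_reveal, dim=-1).indices
            x.scatter_(1, keep, sample_indices.gather(1, keep))
            mask.scatter_(1, keep, torch.zeros_like(keep, dtype=torch.bool))
    return x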

@dome272

dome272 commented Mar 29, 2022

Hey guys, today I reached out to the authors to ask if they would help us with our problem, and the first author replied saying that they are planning to release the code next week (or so).
And since typical sampling is probably not used in their paper, adding it might give an even higher boost in performance.

@pabloppp

pabloppp commented Apr 6, 2022

Guess the official repo is out: https://github.com/google-research/maskgit (although it seems to be in JAX)
Let's find out (and share here please 🙏 ) what was missing in our implementations 🙇

@dome272

dome272 commented Apr 6, 2022

Yea please report if anyone finds the cliffhanger....

@pabloppp

pabloppp commented Apr 6, 2022

Small finding:
Seems like @LeeDoYup was on the right track
https://github.com/google-research/maskgit/blob/cf615d448642942ddebaa7af1d1ed06a05720a91/maskgit/libml/parallel_decode.py#L127
They sample completely randomly and then keep the ones that have the highest probability, but they add noise to those probabilities, and the noise is linearly decayed.

So step 0 basically samples tokens at random, then the next step is also random but less so, etc... :/

I guess we should try this, but it seems like a lot of randomness XD
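
If I'm reading the JAX code right, a PyTorch translation of that selection step would look roughly like this (my own sketch, not a verified port):

import torch

def mask_by_random_topk(mask_len, probs, temperature):
    # probs: (batch, seq) probability of the token sampled at each position.
    # mask_len: how many positions should stay masked after this step.
    # Gumbel(0, 1) noise, scaled by a temperature that gets annealed towards 0,
    # is added to the log-probabilities before picking the most confident tokens.
    gumbel = -torch.log(-torch.log(torch.rand_like(probs).clamp_min(1e-20)))
    confidence = torch.log(probs) + temperature * gumbel
    cutoff = torch.sort(confidence, dim=-1).values[:, mask_len]  # mask_len lowest stay masked
    return confidence >= cutoff.unsqueeze(-1)                    # True = reveal this token

# the temperature passed in is decayed with (1 - ratio), ratio going from 0 to 1 over
# the decoding steps, so early steps are noisy and later steps are nearly greedy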

@LeeDoYup

LeeDoYup commented Apr 7, 2022

Yes, when I look at https://github.com/google-research/maskgit/blob/cf615d448642942ddebaa7af1d1ed06a05720a91/maskgit/libml/parallel_decode.py#L49-L56, I conclude that the paper was wrong. Mixing in randomness is the key to the algorithm...

@dome272

dome272 commented Apr 7, 2022

If any of you have translated the sampling code to PyTorch, could you post it here?

@pabloppp

pabloppp commented Apr 17, 2022

Hi, just wanted to share a small update and see if anyone here is getting new interesting results.
My model conditioned on facial identity keeps getting better the longer I train it. It's still veeeeery far from what it should be able to produce based on the paper, but it feels like the training is very slow and requires a lot of iterations (I'm running my experiments on Google Colab, soooo... not ideal XD). I tried to replicate the official sampling schedule as closely as possible, but ended up making some tweaks because I feel they work slightly better in my case:

  • I do both softmax temperature annealing (from 2 to 1) and mask temperature annealing (from 1 to 0, as in the official repository).
  • Since my model uses an identity embedding as input, I zero it out for 10% of the training iterations and then use classifier-free guidance during sampling: I predict the logits conditioned on the id embedding, predict the logits with the zeroed embedding, and then compute conditional_logits * 2 - zeroed_logits and use those logits for the rest of the process (I think that's why my initial temperature has to be 2, to re-scale things when doing the softmax). There's a rough sketch of this step right after this list.
  • I noticed that typical sampling gives better detail in the faces but much flatter hair/backgrounds; I have it as a parameter of my model and keep testing things, but for now I default to the setup above.
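
The classifier-free guidance step from the second bullet looks roughly like this (a sketch with a hypothetical model signature, not my exact forward pass):

import torch

def guided_logits(model, x, mask, id_embedding, guidance_scale=2.0):
    # The identity embedding is zeroed for ~10% of the training iterations so the
    # model also learns the unconditional distribution; at sampling time we push
    # the prediction towards the conditional one.
    cond_logits, _ = model(x, id_embedding, mask)
    uncond_logits, _ = model(x, torch.zeros_like(id_embedding), mask)
    # guidance_scale = 2 gives conditional_logits * 2 - zeroed_logits, as described above
    return uncond_logits + guidance_scale * (cond_logits - uncond_logits)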

image

If I sample multiple images using the same id embedding, I get something like this: the identity is not very well preserved, but there are clearly common traits that the model is able to reproduce.
image

I started doing some tests on a V100 using a much larger, more diverse, unlabeled dataset of nature images, in a sort of autoencoder-ish way (image embedding in -> image out), but so far I only manage to get noise. Maybe it's just a matter of training longer on more powerful machines? Maybe there's some way to accelerate training so we don't have to do the equivalent of 300 epochs on ImageNet, as they report in the paper? Idk...

Has anyone else managed to get something?

@GuoxingY

Thanks for sharing. I have also tried to train a transformer-based model on the COCO dataset for image generation, but got worse results. Could you share some images generated by your model trained on the nature dataset? I wonder whether the number of training iterations is the key to training a good model, or whether I am missing some details in my implementation.

@pabloppp

Sure. Depending on how I do the sampling (higher/lower temperature), I get things like these two examples (this is after about 900k iterations with batch size 6, on a dataset of 270k images):
image
image

I feel like the model is starting to learn which tokens are more common in this sort of image, but it's still very random and produces "not completely random" noise.

With 300 epochs of ImageNet, the model sees about 4,200,000,000 images, so my training is only about 0.12% of what they did in the paper, and it still took me several days on a V100 😓

@LeeDoYup

I finished training a 200M-parameter model on ImageNet for 300 epochs. When I use the released technique, I get FID=21; when I use temperature scaling when predicting tokens (not the mask), I get FID=11~12, which still does not reproduce the paper's result.

By the way, I think they do not use temperature annealing to randomly select the positions to unmask, since the code below fixes the temperature parameter as a scalar (=4.5). Is that right...?
https://github.com/google-research/maskgit/blob/cf615d448642942ddebaa7af1d1ed06a05720a91/maskgit/libml/parallel_decode.py#L158-L159

@pabloppp

It's 4.5 * (1 - ratio), and ratio goes up from 0 to 1 with the sampling step, so the temperature gets annealed down to 0.

@LeeDoYup

@pabloppp Oh, thank you, I will try it!

@GuoxingY

@pabloppp So maybe training the model longer can get better results, like the ones @LeeDoYup shared? Although the quantitative result is not as good as the paper states, I think the quality of generated images at FID=21 is much better than the images we got. Could you share some generated images here, @LeeDoYup?

@LeeDoYup

@GuoxingY The generated images are ImageNet samples. I will share them in this thread soon.

@dome272

dome272 commented Apr 19, 2022

I trained the described transformer (172M params) for 1000 epochs on a dataset of 8k landscape images with a batch size of 100.
Loss:
image

Samples:
(1. original, 2. reconstructed, 3. Inpainted bottom half, 4. New sampled image)
image
image
image
image

More samples:
image
image
image
image
image
image
image
image

I'll update my code and upload some checkpoints soon if anyone is interested. I tried to follow the paper as closely as possible.
Note: the samples are somewhat cherry-picked.

@pabloppp

How much of the image do you mask for the reconstructed image?

@dome272

dome272 commented Apr 20, 2022

How much of the image do you mask for the reconstructed image?

The wording might have been a bit misleading in that context. The reconstruction is just the image encoded and decoded again and has nothing to do with the transformer. I just put it in for my own understanding.

@zhuqiangLu

I'll update my code and upload some checkpoints soon if anyone is interested.

Hi, has the code been updated?
