Typical training results #61

Open · edend10 opened this issue Mar 1, 2021 · 17 comments

Comments

@edend10

edend10 commented Mar 1, 2021

Hi, great repo and thanks for sharing your work!

I'm trying your bird dataset example from the Colab with OpenAI's pretrained VAE. I haven't been able to get meaningful results so far, either on the Colab or on my own VM (Tesla T4 GPU).

I'm 13 epochs into train_dalle.py and only seeing these kinds of results:
[image: sample generations after 13 epochs]

On my VM I ran $ python train_dalle.py --image_text_folder /parent/to/birds/dataset/directory without changing any of the code (I only replaced wandb with another experiment-tracking framework, but I doubt that should make a difference).

Should the bird dataset work better with the pretrained VAE? Can you share some results or common training parameters/times/number of epochs?

@lucidrains
Owner

@edend10 Hi Eden! Thanks for trying out the repository! I may have found a bug with the pretrained VAE wrapper, fixed in the latest commit https://github.com/lucidrains/DALLE-pytorch/blob/0.2.2/dalle_pytorch/vae.py#L82 🙏 I'll be training this myself this week and ironing out any remaining issues (other than data and scale, of course).

@edend10
Author

edend10 commented Mar 1, 2021

Thanks for the response, @lucidrains!
Ohh interesting, I'll check out the changes and try them out. Will look out for more updates!

@CDitzel

CDitzel commented Mar 2, 2021

What are those two mapping functions for, anyway?

Are they just for transforming the pixel value range of the input data, as they do over at OpenAI?
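
For context, the two mapping functions in OpenAI's released dall_e code are small affine remaps of the pixel range, used because their VAE was trained with a logit-Laplace reconstruction loss. A paraphrased sketch (validation checks omitted):

```python
import torch

logit_laplace_eps = 0.1  # constant from OpenAI's dall_e code

def map_pixels(x: torch.Tensor) -> torch.Tensor:
    # squeeze [0, 1] pixels into [eps, 1 - eps] before encoding
    return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps

def unmap_pixels(x: torch.Tensor) -> torch.Tensor:
    # invert the mapping after decoding and clamp back to [0, 1]
    return torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1)
```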

@AlexanderRayCarlson

Hello! Thank you for this excellent work. I seem to be getting something similar: abstract sorts of blue squares when training in the Colab notebook. It looks like the package (0.2.2) is updated with the latest fix. Is there anything else I need to do at the moment?

@awilson9

awilson9 commented Mar 5, 2021

This is still happening for me as well with the pretrained VAE on 0.2.2.

@afiaka87
Contributor

afiaka87 commented Mar 9, 2021

This is an early output (2 epochs) from the new code that removes the normalization from train_dalle.py. Was that the necessary fix, @lucidrains?

DEPTH = 6
BATCH_SIZE = 8

[image: early training sample]

"a female mannequin"
[image: generated sample for the prompt above]

Much more cohesive and a much stronger start now. No strange blueness, at the very least.
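
In practice, the change amounts to leaving the dataloader output in [0, 1] and letting the OpenAI VAE wrapper apply its own pixel mapping internally. A sketch of what that looks like (the transform pipeline here is an assumption, not the repo's literal code):

```python
import torchvision.transforms as T

transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(256),
    T.ToTensor(),  # float tensors in [0, 1], which the VAE wrapper expects
    # T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # removed: the wrapper remaps pixels itself
])
```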

@liuqk3

liuqk3 commented Mar 10, 2021

Hi @afiaka87, amazing results! Can you share more details about your configuration, such as the dataset, learning rate, LR scheduler, and number of text and image tokens (8192 image tokens, right?)? Thanks.

@afiaka87
Contributor

afiaka87 commented Mar 10, 2021

Hi @afiaka87, amazing results! Can you share more details about your configuration, such as the dataset, learning rate, LR scheduler, and number of text and image tokens (8192 image tokens, right?)? Thanks.

I should mention that the dataset I'm using consists of images released by OpenAI with their DALL-E. The mannequin image is not being generated from text alone; it's from an image-text pair. Anyway, my point is that my dataset is bad and I'm mostly just messing around. It's probably the case that using images generated by DALL-E itself is bound to converge much quicker than usual.

I'm using the defaults in train_dalle.py except for the BATCH_SIZE and DEPTH. Pretrained OpenAI VAE, top_k=0.9, and reversible=True. I tried mixing attention layers but it adds memory. (Edit: I don't think it does, actually; I'm training with all four attn_types currently.)

I'm working on creating a hyperparameter sweep with wandb currently. I think a learning rate of 2e-4 might be better for depth greater than 12 or so.

I still can't get a stable learning rate with 64 depth.
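
Putting that configuration together, the model setup would look roughly like this (a sketch against the DALLE-pytorch API of the time; num_text_tokens is an assumed default):

```python
import torch
from dalle_pytorch import OpenAIDiscreteVAE, DALLE

vae = OpenAIDiscreteVAE()  # pretrained OpenAI discrete VAE (8192 image tokens)

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,  # assumed default vocab size
    text_seq_len = 256,
    depth = 6,
    heads = 8,
    dim_head = 64,
    reversible = True,  # reversible layers to save memory
    attn_types = ('full', 'axial_row', 'axial_col', 'conv_like')  # the four mixed attention types
)
```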

@afiaka87
Contributor

afiaka87 commented Mar 10, 2021

Edit: you can find the whole training session here: https://wandb.ai/afiaka87/dalle-pytorch-openai-samples/reports/Training-on-OpenAI-DALL-E-Generated-Images--Vmlldzo1MTk2MjQ?accessToken=89u5e10c2oag5mlv46xm2sz6orkyqdlwjrsj8vd95oz8ke3ez6v8v2fh07klk6j1

I'm starting over because there have been updates to the main branch.

Original post:

"a professional high quality emoji of a spider starfish chimera . a spider imitating a starfish . a spider made of starfish . a professional emoji ."

[image: generated samples for the prompt above]

Left it running at 16 depth, 8 heads, batch size of 12, learning_rate=2e-4. The loss is going down at a steady, consistent rate. (Edit: just kidding! It seems to get stuck at around ~6.0 on this run, which seems high?)

DEPTH: 16
HEADS: 8
TOP_K: 0.85
EPOCHS: 27
SHUFFLE: True
DIM_HEAD: 64
MODEL_DIM: 512
BATCH_SIZE: 12
REVERSIBLE: true
TEXT_SEQ_LEN: 256
LEARNING_RATE: 0.0002
GRAD_CLIP_NORM: 0.5
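
For what it's worth, TOP_K above presumably maps to the filter_thres argument used at sampling time; a minimal generation sketch (the text tokens are placeholders):

```python
import torch

# assumes `dalle` is a trained DALLE instance from dalle_pytorch
text = torch.randint(0, 10000, (1, 256))  # placeholder token ids, shape (batch, text_seq_len)
images = dalle.generate_images(text, filter_thres = 0.85)  # top-k filtering at 0.85
print(images.shape)  # e.g. (1, 3, 256, 256)
```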

@afiaka87
Contributor

afiaka87 commented Mar 10, 2021

Edit: here, I used Weights & Biases to create a report. The link below has all the images generated (every 100th iteration) across 27,831 total iterations. This one should work, I think:

https://wandb.ai/afiaka87/dalle-pytorch-openai-samples/reports/Training-on-OpenAI-DALL-E-Generated-Images--Vmlldzo1MTk2MjQ?accessToken=89u5e10c2oag5mlv46xm2sz6orkyqdlwjrsj8vd95oz8ke3ez6v8v2fh07klk6j1

@tommy19970714

@afiaka87 Thank you for sharing your Weights & Biases report!
But I can't see it because its project is private.
Can you make it visible to us?
[screenshot: private W&B project page]

@afiaka87
Contributor

afiaka87 commented Mar 11, 2021

Hi @afiaka87, amazing results! Can you share more details about your configuration, such as the dataset, learning rate, LR scheduler, and number of text and image tokens (8192 image tokens, right?)? Thanks.

Just for more info on the dataset itself: it is roughly 1,100,000 256x256 image-text pairs that were generated by OpenAI's DALL-E. They presented roughly ~30k unique text prompts, for which they posted the top 32 of 512 generations on https://openai.com/blog/dall-e/. Many images were corrupt, and not every prompt has a full 32 examples, but the total winds up being about 1.1 million images. If you look at many of the examples on that page, you'll see that DALL-E (in that form, at least) can and will make mistakes, and those mistakes are in this dataset too. Anyway, I'm just messing around and having fun training; this is definitely not going to produce a good model or anything.

There are also a large number of images in the dataset which are intended to be used with the "mask" feature. I don't know if that's possible yet in DALLE-pytorch, though. Anyway, that can't be helping much.

One valuable thing I've taken from this is that it seems to take at least ~2,000 iterations at a batch size of 4 to approach any sort of coherent reproductions. The exact number probably varies, but in terms of "knowing when to start over", roughly 3,000 steps might be a good soft target.

@tommy19970714

Thank you for sharing your results!
I will refer to your parameters.

@afiaka87
Contributor

@tommy19970714

I did a hyperparameter sweep with Weights & Biases: forty-eight 1,200-iteration runs of dalle-pytorch, varying learning rate, depth, and heads (minimizing the total loss at the end of each run).

#84 (comment)
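
A sweep along those lines could be set up like this (a sketch only; the metric name, ranges, and search method are assumptions, not the actual sweep config):

```python
import wandb

def train():
    # hypothetical wrapper around the train_dalle.py training loop;
    # wandb.init() exposes the sweep-assigned hyperparameters via run.config
    run = wandb.init()
    lr, depth, heads = run.config.learning_rate, run.config.depth, run.config.heads
    # ... build DALLE with these values, train for 1200 iterations, log 'loss' ...

sweep_config = {
    'method': 'bayes',  # assumed search strategy
    'metric': {'name': 'loss', 'goal': 'minimize'},  # minimize final training loss
    'parameters': {
        'learning_rate': {'min': 1e-4, 'max': 5e-4},
        'depth': {'values': [8, 16, 24, 32]},
        'heads': {'values': [4, 8, 16]},
    },
}

sweep_id = wandb.sweep(sweep_config, project='dalle-pytorch-sweep')
wandb.agent(sweep_id, function=train, count=48)  # one agent running all 48 trials
```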

@afiaka87
Contributor

The most important thing to note here is that the learning rate actually needs to go up to about 0.0005 when dealing with a depth of ~26-32.

@afiaka87
Contributor

I've done a much longer training session on that same dataset here:

#86
