Reconstruction image is always a solid color #12

Closed
jpfeil opened this issue Nov 16, 2023 · 20 comments

@jpfeil
Contributor

jpfeil commented Nov 16, 2023

Hello,

I've been training this on ImageNet, but I'm concerned I'm doing something wrong because the reconstructions are always a solid color. I haven't trained it very long (~1500 steps at batch size 10), but I wanted to check whether this is expected.

1300 steps: [reconstruction image]

1200 steps: [reconstruction image]

```python
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

tokenizer = VideoTokenizer(
    image_size = 256,
    codebook_size=1_024,
    use_gan=True,
    use_fsq=True,
    init_dim=128,
    adversarial_loss_weight=0.1, # From the paper
    perceptual_loss_weight=0.1, # From the paper
    grad_penalty_loss_weight=10.0,
    lfq_entropy_loss_weight=0.3, # From the paper
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'linear_attend_space',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/imagenet/TRAIN',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 8,
    num_train_steps = 1_000_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=1e-4, # From the paper
    accelerate_kwargs={"split_batches": True, "mixed_precision": 'fp16'},
    random_split_seed=171,
    optimizer_kwargs={"betas": (0.9, 0.99)}, # From the paper
    ema_kwargs={}
)

trainer.train()
```

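
A quick way to see how aggressive this config is, as a sketch: assuming each `'compress_space'` entry halves height and width and no other layer type changes spatial resolution (an assumption about magvit2_pytorch, not something stated in this thread), this 256×256 config lands on a 16×16 latent grid, and so does the 32×32 Fashion-MNIST config posted later in the thread. The ImageNet model therefore compresses 16× per side, versus 2× for Fashion-MNIST. `latent_grid_side` is a hypothetical helper written only for this illustration.

```python
# Hypothetical helper. Assumes each 'compress_space' halves H and W and that no other
# layer type (residual, attention, ('consecutive_residual', n) tuples) changes resolution.
def latent_grid_side(image_size: int, layers: tuple) -> int:
    side = image_size
    for layer in layers:
        if layer == 'compress_space':  # tuples never compare equal to this string
            assert side % 2 == 0, f"side {side} is not divisible by 2"
            side //= 2
    return side

imagenet_layers = (
    'residual',
    'compress_space', ('consecutive_residual', 2),
    'compress_space', ('consecutive_residual', 2),
    'compress_space', ('consecutive_residual', 2),
    'linear_attend_space',
    'compress_space', ('consecutive_residual', 2),
    'attend_space',
)
mnist_layers = ('residual', 'compress_space', ('consecutive_residual', 2), 'attend_space')

for name, size, layers in [('imagenet', 256, imagenet_layers), ('fashion-mnist', 32, mnist_layers)]:
    side = latent_grid_side(size, layers)
    print(f"{name}: {size}x{size} -> {side}x{side} latent ({side * side} tokens per frame)")
```
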
@lucidrains
Owner

lucidrains commented Nov 16, 2023

@jpfeil could you retry with fp32 and train until 5000 steps? also, a grad accum of 4-6 is sufficient (32-64 effective batch size)
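
For reference, the arithmetic behind these suggestions, as a sketch. It assumes each trainer step runs `grad_accum_every` micro-batches of `batch_size`, and relies on Accelerate's documented `split_batches=True` behavior, where `batch_size` is the global batch split across processes rather than a per-process batch. `mixed_precision="no"` is Accelerate's setting for full fp32; `effective_batch_size` is a hypothetical helper.

```python
# Hypothetical helper for the effective batch size per optimizer step.
def effective_batch_size(batch_size: int, grad_accum_every: int,
                         num_processes: int = 1, split_batches: bool = True) -> int:
    # split_batches=True: batch_size is already the global batch across all processes
    # split_batches=False: each process loads its own batch_size samples
    per_step = batch_size if split_batches else batch_size * num_processes
    return per_step * grad_accum_every

# fp32 instead of fp16: Accelerate's mixed_precision="no" disables mixed precision
accelerate_kwargs = {"split_batches": True, "mixed_precision": "no"}

for accum in (4, 6, 8):
    print(f"batch_size=10, grad_accum_every={accum} -> effective batch {effective_batch_size(10, accum)}")
```

With `batch_size = 10`, accumulating 4-6 micro-batches gives 40-60 samples per step, inside the suggested 32-64 range; the original `grad_accum_every = 8` gives 80.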

@lucidrains
Owner

lucidrains commented Nov 16, 2023

@jpfeil also share your training curve, try out wandb's report feature for easy sharing

@jpfeil
Contributor Author

jpfeil commented Nov 16, 2023

Thanks @lucidrains I'll let you know when the wandb report is ready.

@jpfeil
Contributor Author

jpfeil commented Nov 17, 2023

@lucidrains This was run on 0.1.24, so I'm going to pull the latest version and retry. The loss was slowly improving, but around step 1000 it became NaN. The only change I made was adding a cosine schedule with warmup. I'm also still using bf16, so I'll change that in the next run.
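
The thread doesn't show the actual warmup or cosine settings, so the following is a hypothetical sketch of a linear-warmup plus cosine-decay schedule, together with a fail-fast check for non-finite losses like the NaN reported around step 1000. The function names and step counts are illustrative assumptions.

```python
import math

# Hypothetical linear-warmup + cosine-decay schedule (not the settings used in this run).
def warmup_cosine_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int,
                     min_lr: float = 0.0) -> float:
    if step < warmup_steps:
        # linear ramp: base_lr / warmup_steps at step 0, base_lr at step warmup_steps - 1
        return base_lr * (step + 1) / warmup_steps
    # cosine decay from base_lr (at warmup_steps) down to min_lr (at total_steps)
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

def assert_finite(loss: float, step: int) -> None:
    # stop immediately instead of continuing to train on NaN/inf values,
    # which reduced-precision overflow can produce
    if not math.isfinite(loss):
        raise FloatingPointError(f"non-finite loss {loss} at step {step}")

for s in (0, 999, 1_000, 5_500, 10_000):
    print(s, warmup_cosine_lr(s, 1e-4, warmup_steps=1_000, total_steps=10_000))
```

To drive a PyTorch optimizer with it, one option is `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda s: warmup_cosine_lr(s, 1.0, warmup_steps, total_steps))`, since `LambdaLR` multiplies the optimizer's initial learning rate by the returned factor.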

https://api.wandb.ai/links/pfeiljx/p2x7x2x2

@jpfeil
Copy link
Contributor Author

jpfeil commented Nov 19, 2023

Hi @lucidrains

I ran it using fp32 and trained for 5000 steps, but I did not see any improvement.

https://api.wandb.ai/links/pfeiljx/8kqeyypi

Let me know if you have any suggestions.

@jpfeil
Contributor Author

jpfeil commented Nov 20, 2023

@lucidrains I ran on Fashion-MNIST last night and the model was able to converge:

https://api.wandb.ai/links/pfeiljx/udspvdgu

```python
import torch
from datetime import datetime
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d %H:%M:%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    codebook_size=1_024,
    use_gan=True,
    use_fsq=True,
    init_dim=128, # From the paper
    adversarial_loss_weight=0.1, # From the paper
    perceptual_loss_weight=0.1, # From the paper
    grad_penalty_loss_weight=10.0,
    lfq_entropy_loss_weight=0.3, # From the paper
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 5,
    grad_accum_every = 5,
    num_train_steps = 5_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=2e-5,
    accelerate_kwargs={"split_batches": True},
    random_split_seed=85,
    optimizer_kwargs={"betas": (0.9, 0.99)}, # From the paper
    ema_kwargs={},
    use_wandb_tracking=True,
    checkpoints_folder=f'./runs/{RUNTIME}/checkpoints',
    results_folder=f'./runs/{RUNTIME}/results',
)


with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.26 {RUNTIME}'):
    trainer.train()
```

@lucidrains
Owner

lucidrains commented Nov 20, 2023

@jpfeil @jacobpfeil i think this repository should support pretraining with 2d conv layers, and then a way to convert it to 3d for video. but let me meditate on the simplest way to achieve this
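
One common way to do such a 2D-to-3D conversion is weight "inflation". This is an illustration, not what the repository implements. I3D repeats the 2D kernel along time and divides by the temporal size; an alternative places the 2D kernel in a single temporal slice and zeros the rest. For a causal temporal conv (past frames padded, current frame aligned with the last tap), putting the kernel in the last slice makes each output frame depend only on its own input frame, so the inflated model starts out behaving exactly like the per-frame 2D model. A minimal NumPy sketch; `inflate_conv2d_weight` and its `mode` names are hypothetical:

```python
import numpy as np

# Hypothetical helper: inflate a 2D conv kernel (out, in, kh, kw) to 3D (out, in, kt, kh, kw).
def inflate_conv2d_weight(w2d: np.ndarray, kt: int, mode: str = "causal") -> np.ndarray:
    out_ch, in_ch, kh, kw = w2d.shape
    w3d = np.zeros((out_ch, in_ch, kt, kh, kw), dtype=w2d.dtype)
    if mode == "causal":
        # with left-only temporal padding, the last tap sees the current frame
        w3d[:, :, -1] = w2d
    elif mode == "center":
        # for symmetric temporal padding, the middle tap sees the current frame
        w3d[:, :, kt // 2] = w2d
    elif mode == "mean":
        # I3D-style: repeat over time and divide by kt; on a static clip this matches
        # the 2D output away from the padded boundary frames
        w3d[:] = w2d[:, :, None] / kt
    else:
        raise ValueError(f"unknown mode {mode!r}")
    return w3d

w2 = np.random.default_rng(0).standard_normal((8, 4, 3, 3))
print(inflate_conv2d_weight(w2, kt=3).shape)
```

In PyTorch, the same function could be applied to `conv2d.weight.detach().numpy()` and the result copied into a `Conv3d` with `kernel_size=(kt, kh, kw)`; the bias carries over unchanged.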

@jpfeil
Contributor Author

jpfeil commented Nov 20, 2023

Thanks @lucidrains. Let me know if I can help run some tests. I have access to a few A100 GPUs.

@lucidrains
Owner

@jpfeil sounds good

let me think about this for a few days or the code will come out wrong

measure twice cut once kinda thing

@jpfeil
Contributor Author

jpfeil commented Nov 20, 2023

@lucidrains After looking at the FashionMNIST results, it looks like the discriminator collapsed to zero loss. So, I think the learning stopped prematurely. I'm also not getting good reconstructions.
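
If the discriminator is trained with a hinge loss (a common choice in VQGAN-style tokenizers; check which loss your installed version actually uses), "zero discriminator loss" has a concrete meaning: every real logit is at least 1 and every fake logit is at most -1. In that regime the discriminator's own loss stops giving it gradient, while the generator's adversarial term is still `-mean(D(fake))`. A minimal pure-Python sketch of the two hinge losses:

```python
# Sketch of hinge GAN losses, assuming raw (unbounded) discriminator logits.
def relu(x: float) -> float:
    return max(0.0, x)

def hinge_discr_loss(fake_logits, real_logits) -> float:
    # penalize fake logits above -1 and real logits below +1
    fake_term = sum(relu(1.0 + f) for f in fake_logits) / len(fake_logits)
    real_term = sum(relu(1.0 - r) for r in real_logits) / len(real_logits)
    return fake_term + real_term

def hinge_gen_loss(fake_logits) -> float:
    # the generator tries to push D(fake) up
    return -sum(fake_logits) / len(fake_logits)

print(hinge_discr_loss([-1.5, -2.0], [1.2, 3.0]))  # fully separated: loss is zero
print(hinge_discr_loss([0.0], [0.0]))              # undecided discriminator
```
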

[image: sampled 17]

For VQ-GAN, I've read that the autoencoder needs a couple of epochs to generate good images before the discriminator is switched on. Is there a way to do that here?
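
The VQGAN recipe described here delays the adversarial term: taming-transformers gates its discriminator factor with an `adopt_weight` helper and a `disc_start` step. Below is a hypothetical sketch of that gating. The loss composition, the default weights, and the `disc_start` value are illustrative assumptions, not magvit2_pytorch internals.

```python
# Hypothetical sketch of delayed-discriminator gating (in the spirit of taming-transformers'
# adopt_weight / disc_start); values and loss terms are illustrative only.
def gated_weight(step: int, weight: float, disc_start: int) -> float:
    # 0 before disc_start, full weight afterwards
    return weight if step >= disc_start else 0.0

def generator_loss(step: int, recon: float, perceptual: float, adversarial: float,
                   perceptual_weight: float = 0.1, adversarial_weight: float = 0.1,
                   disc_start: int = 10_000) -> float:
    adv_w = gated_weight(step, adversarial_weight, disc_start)
    return recon + perceptual_weight * perceptual + adv_w * adversarial

def discriminator_update_enabled(step: int, disc_start: int = 10_000) -> bool:
    # also skip discriminator optimizer steps, so it cannot race ahead of a blurry autoencoder
    return step >= disc_start

print(generator_loss(0, 1.0, 2.0, 5.0), generator_loss(10_000, 1.0, 2.0, 5.0))
```

Passing `adversarial_weight=0.0` removes the adversarial term entirely, leaving a plain reconstruction-plus-perceptual objective.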

@lucidrains
Owner

lucidrains commented Nov 20, 2023

@jpfeil yea i could add that, but only if need be

what happens if you set adversarial_loss_weight to 0?

it really should converge for fashion mnist quite quickly, even without the GAN system

@jpfeil
Contributor Author

jpfeil commented Nov 20, 2023

I get an assertion error because the self.has_gan attribute gets set to False. Is it okay to override that assertion?

@lucidrains
Owner

@jpfeil could you point to the line number?

could you also give 0.1.29 a quick try? may be a bug but not entirely sure

@lucidrains
Owner

@jpfeil oh nvm, yes i see it. we should be able to turn off adversarial loss, let me fix

@lucidrains
Owner

lucidrains commented Nov 20, 2023

@jpfeil try 0.1.31 with use_gan = False on the VideoTokenizer

@jpfeil
Contributor Author

jpfeil commented Nov 21, 2023

Whoops. My Tokenizer change wasn't saved. Running now...

@lucidrains
Owner

@jpfeil give the imagenet run another try

there may have been a bug with how I zeroed the gradients a few patches ago
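
The thread doesn't say what the gradient-zeroing bug was. As a hypothetical illustration of why zeroing placement matters with `grad_accum_every`: gradients should be zeroed once per optimizer step (in PyTorch, `optimizer.zero_grad()` before the micro-batch loop), with each micro-batch loss scaled by `1 / grad_accum_every`. Zeroing inside the loop silently keeps only the last micro-batch. A pure-Python toy with a one-parameter model `y_hat = w * x`:

```python
# Toy model y_hat = w * x with mean-squared-error loss; pure Python so it runs anywhere.
def mse_grad(w, xs, ys):
    # d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w, xs, ys, grad_accum_every):
    assert len(xs) % grad_accum_every == 0, "micro-batches must be equal-sized"
    n = len(xs) // grad_accum_every
    grad = 0.0  # zero the gradient ONCE per optimizer step
    for i in range(grad_accum_every):
        mx, my = xs[i * n:(i + 1) * n], ys[i * n:(i + 1) * n]
        # scale each micro-batch by 1/grad_accum_every so the sum equals the full-batch mean
        grad += mse_grad(w, mx, my) / grad_accum_every
    return grad

def buggy_accumulated_grad(w, xs, ys, grad_accum_every):
    n = len(xs) // grad_accum_every
    grad = 0.0
    for i in range(grad_accum_every):
        grad = 0.0  # hypothetical BUG: zeroing inside the loop discards earlier micro-batches
        mx, my = xs[i * n:(i + 1) * n], ys[i * n:(i + 1) * n]
        grad += mse_grad(w, mx, my) / grad_accum_every
    return grad

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]
print(mse_grad(0.5, xs, ys), accumulated_grad(0.5, xs, ys, 4), buggy_accumulated_grad(0.5, xs, ys, 4))
```
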

@jpfeil
Contributor Author

jpfeil commented Nov 27, 2023

This is resolved for Fashion-MNIST, but I haven't been able to run through enough ImageNet data to see if it works there. I'm going to close this now, and if it comes up again for ImageNet, I'll open a new issue.

jpfeil closed this as completed Nov 27, 2023
@coolbunnyx

coolbunnyx commented Dec 30, 2023

> This is resolved for Fashion-MNIST, but I haven't been able to run through enough ImageNet data to see if it works there. I'm going to close this now, and if it comes up again for ImageNet, I'll open a new issue.

Hi @jpfeil, do you mind sharing how you ended up solving it? I ran into the same issue (#25).

@jpfeil
Contributor Author

jpfeil commented Jan 3, 2024

Hi @coolbunnyx,

Sorry for the delay. I think you already solved it, but I was able to get good reconstructions after training for longer.
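
A rough estimate of how little data the short runs in this thread covered, as a sketch. It assumes each trainer step consumes `batch_size * grad_accum_every` images, that the ImageNet folder holds the 1,281,167 ImageNet-1k training images, and that Fashion-MNIST has its standard 60,000 training images. It ignores the validation split carved out via `random_split_seed`. `epochs_seen` is a hypothetical helper.

```python
IMAGENET_1K_TRAIN = 1_281_167  # ILSVRC-2012 training images
FASHION_MNIST_TRAIN = 60_000

# Hypothetical helper: images consumed divided by images in the dataset.
def epochs_seen(steps: int, batch_size: int, grad_accum_every: int,
                dataset_size: int = IMAGENET_1K_TRAIN) -> float:
    return steps * batch_size * grad_accum_every / dataset_size

print(f"ImageNet, 1,500 steps: {epochs_seen(1_500, 10, 8):.3f} epochs")
print(f"Fashion-MNIST, 5,000 steps: {epochs_seen(5_000, 5, 5, FASHION_MNIST_TRAIN):.3f} epochs")
```

Under these assumptions, the ~1500-step ImageNet run (`batch_size = 10`, `grad_accum_every = 8`) covered under a tenth of one epoch, while the 5,000-step Fashion-MNIST run (`batch_size = 5`, `grad_accum_every = 5`) covered about two.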
