Vqgan training #52

Status: Open. Wants to merge 42 commits into base branch main.

Commits (42)
All 42 commits are by isamu-isozaki.

90aaf6d  Added basic idea (Apr 21, 2023)
59efb88  Adding basic idea (Apr 21, 2023)
cc2bc90  First idea for training loop (Apr 27, 2023)
d6067a4  Added perliminary generation (Apr 27, 2023)
bbfd757  Merge branch 'main' of https://github.com/huggingface/muse into vqgan… (Apr 27, 2023)
508a00e  Added ema (Apr 28, 2023)
f6dd2a1  Making config and removed wandb (Apr 30, 2023)
08b2f12  Removed folder (Apr 30, 2023)
fa4dc0d  Fixing configs (Apr 30, 2023)
cfd0c19  Finished basic vqgan testing (Apr 30, 2023)
4615048  Removed folders (May 8, 2023)
026305a  Removed config (May 8, 2023)
290f287  Adding discriminator warmup (May 8, 2023)
9997d82  Starting adding projected gan tech (May 8, 2023)
eb86603  Updated config (May 8, 2023)
02e5c74  Update docs (May 8, 2023)
a597b1a  Adding slurm file (May 10, 2023)
105aac7  Updated config (May 15, 2023)
6872bba  Moving tqdm to batch (May 15, 2023)
38346ed  Tried updating webdb (May 15, 2023)
4f296fc  Fixed config (May 15, 2023)
75f8631  Fixing fmap (May 15, 2023)
e29acf3  Fixed zero grad (May 15, 2023)
7994f4e  Fixed discriminator training (May 15, 2023)
c3e6139  Increased batch size (May 15, 2023)
277edf1  Changed mixed precision (May 15, 2023)
a7035ab  Fixed oom (May 15, 2023)
2717aa9  Fixed debugging (May 15, 2023)
d6cb053  sanity check (May 15, 2023)
74eb612  sanity check (May 15, 2023)
e800cae  Fixed training (May 15, 2023)
ca9535e  sanity check (May 15, 2023)
8e21e99  Fixing logs (May 15, 2023)
3588307  Fixing logs (May 15, 2023)
d88c402  Fixing logs (May 15, 2023)
61d0d69  Fixing logs (May 15, 2023)
c71255a  Added spectral norm f16 config (May 15, 2023)
851f583  Fixing logs (May 15, 2023)
1bb0a65  saving discriminator too (May 16, 2023)
2f81055  Fixing logs (May 16, 2023)
9d9b227  Made training distributed (May 16, 2023)
c23ae1c  Properly running vqgan training (May 17, 2023)
Files changed
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,5 +1,7 @@
 # Byte-compiled / optimized / DLL files
 output.jpg
+imagenet-vqgan-training
+wandb
 __pycache__/
 *.py[cod]
 *$py.class
95 changes: 95 additions & 0 deletions configs/imagenet_vqgan_training.yaml
@@ -0,0 +1,95 @@
wandb:
    entity: null

experiment:
    project: "muse"
    name: "imagenet-vqgan-training"
    output_dir: "imagenet-vqgan-training"
    max_train_examples: 1281167 # total number of imagenet examples
    max_eval_examples: 12800
    save_every: 1000
    eval_every: 1000
    generate_every: 1000
    log_every: 30
    log_grad_norm_every: 500
    resume_from_checkpoint: False
    resume_lr_scheduler: True

model:
    vq_model:
        type: "taming_vqgan"
        pretrained: "openMUSE/vqgan-f16-8192-laion"
    gradient_checkpointing: True
    enable_xformers_memory_efficient_attention: True

dataset:
    type: "classification"
    params:
        train_shards_path_or_url: "pipe:aws s3 cp s3://s-laion/muse-imagenet/imagenet-train-{000000..000320}.tar -"
        eval_shards_path_or_url: "pipe:aws s3 cp s3://s-laion/muse-imagenet/imagenet-val-{000000..000012}.tar -"
        imagenet_class_mapping_path: "/fsx/Isamu/data/imagenet-class-mapping.json"
        validation_prompts_file: null
        batch_size: ${training.batch_size}
        shuffle_buffer_size: 1000
        num_workers: 4
        resolution: 256
        pin_memory: True
        persistent_workers: True
    preprocessing:
        max_seq_length: 16
        resolution: 256
        center_crop: True
        random_flip: False

discriminator:
    dim: 64
    channels: 3
    groups: 32
    init_kernel_size: 5
    kernel_size: 3
    act: "silu"
    discr_layers: 4

optimizer:
    name: adamw
    params: # default adamw params
        learning_rate: 1e-4
        scale_lr: False # scale learning rate by total batch size
        beta1: 0.9
        beta2: 0.999
        weight_decay: 0.01
        epsilon: 1e-8
        discr_learning_rate: 1e-4

lr_scheduler:
    scheduler: "constant_with_warmup"
    params:
        learning_rate: ${optimizer.params.learning_rate}
        warmup_steps: 1000

training:
    gradient_accumulation_steps: 2
    batch_size: 16
    mixed_precision: "bf16"
    enable_tf32: True
    use_ema: False
    seed: 9345104
    max_train_steps: 200000
    overfit_one_batch: False
    cond_dropout_prob: 0.1
    min_masking_rate: 0.0
    label_smoothing: 0.0
    max_grad_norm: null
    guidance_scale: 2.0
    generation_timesteps: 8
    # related to vae code sampling
    use_soft_code_target: False
    use_stochastic_code: False
    soft_code_temp: 1.0
    timm_discriminator_backend: "vgg19"
    timm_disc_layers: "features|pre_logits|head"
    timm_discr_offset: 0
    vae_loss: "l2"
    num_validation_log: 4
    discriminator_warmup: 10000
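The `${...}` values in this config (e.g. `batch_size: ${training.batch_size}`) resolve against other keys at load time. A minimal sketch of how they behave, assuming the configs are loaded with OmegaConf, whose interpolation and `key=value` override syntax the `config=configs/...` command line in the slurm script below suggests:

```python
# Minimal sketch: resolving the ${...} interpolations in this config.
# Assumes OmegaConf is the loader (this is its interpolation syntax).
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/imagenet_vqgan_training.yaml")

# dataset.params.batch_size is declared as ${training.batch_size},
# so reading it returns the value defined under training:
assert cfg.dataset.params.batch_size == cfg.training.batch_size  # 16

# lr_scheduler.params.learning_rate mirrors optimizer.params.learning_rate:
print(cfg.lr_scheduler.params.learning_rate)  # 1e-4

# Command-line overrides such as training.batch_size=160 (see the slurm
# script) can be merged as a dotlist before the values are read:
cli = OmegaConf.from_dotlist(["training.batch_size=160"])
cfg = OmegaConf.merge(cfg, cli)
print(cfg.dataset.params.batch_size)  # 160, via the interpolation
```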
95 changes: 95 additions & 0 deletions configs/imagenet_vqgan_training_jewels.yaml
@@ -0,0 +1,95 @@
wandb:
    entity: null

experiment:
    project: "muse"
    name: "imagenet-vqgan-training"
    output_dir: "imagenet-vqgan-training"
    max_train_examples: 1281167 # total number of imagenet examples
    max_eval_examples: 12800
    save_every: 1000
    eval_every: 1000
    generate_every: 1000
    log_every: 30
    log_grad_norm_every: 500
    resume_from_checkpoint: False
    resume_lr_scheduler: True

model:
    vq_model:
        type: "taming_vqgan"
        pretrained: "openMUSE/vqgan-f16-8192-laion"
    gradient_checkpointing: True
    enable_xformers_memory_efficient_attention: True

dataset:
    type: "classification"
    params:
        train_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-train-{000000..000320}.tar"
        eval_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-val-{000000..000012}.tar"
        imagenet_class_mapping_path: "/p/scratch/ccstdl/muse/imagenet-class-mapping.json"
        validation_prompts_file: null
        batch_size: ${training.batch_size}
        shuffle_buffer_size: 1000
        num_workers: 4
        resolution: 256
        pin_memory: True
        persistent_workers: True
    preprocessing:
        max_seq_length: 16
        resolution: 256
        center_crop: True
        random_flip: False

discriminator:
    dim: 64
    channels: 3
    groups: 32
    init_kernel_size: 5
    kernel_size: 3
    act: "silu"
    discr_layers: 4

optimizer:
    name: adamw
    params: # default adamw params
        learning_rate: 1e-4
        scale_lr: False # scale learning rate by total batch size
        beta1: 0.9
        beta2: 0.999
        weight_decay: 0.01
        epsilon: 1e-8
        discr_learning_rate: 1e-4

lr_scheduler:
    scheduler: "constant_with_warmup"
    params:
        learning_rate: ${optimizer.params.learning_rate}
        warmup_steps: 1000

training:
    gradient_accumulation_steps: 2
    batch_size: 16
    mixed_precision: "bf16"
    enable_tf32: True
    use_ema: False
    seed: 9345104
    max_train_steps: 200000
    overfit_one_batch: False
    cond_dropout_prob: 0.1
    min_masking_rate: 0.0
    label_smoothing: 0.0
    max_grad_norm: null
    guidance_scale: 2.0
    generation_timesteps: 8
    # related to vae code sampling
    use_soft_code_target: False
    use_stochastic_code: False
    soft_code_temp: 1.0
    timm_discriminator_backend: "vgg19"
    timm_disc_layers: "features|pre_logits|head"
    timm_discr_offset: 0
    vae_loss: "l2"
    num_validation_log: 4
    discriminator_warmup: 10000
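All of these configs set `discriminator_warmup: 10000`, matching the "Adding discriminator warmup" commit. The diff does not include `training/train_vqgan.py`, so the following is only a hedged sketch of the pattern such a knob usually controls: the adversarial term in the generator loss stays zeroed until the discriminator has trained for the warmup steps. The names `gan_loss_weight` and `global_step` are illustrative, not taken from the PR:

```python
# Hedged sketch of a discriminator warmup gate; not the PR's exact code.
def gan_loss_weight(global_step: int, warmup_steps: int = 10000,
                    weight: float = 0.1) -> float:
    """Adversarial loss weight: zero until warmup finishes, then constant."""
    return 0.0 if global_step < warmup_steps else weight

# Illustrative use inside a VQGAN training step: the reconstruction and
# codebook terms always apply, the GAN term switches on only after warmup.
# loss = rec_loss + codebook_loss + gan_loss_weight(step) * g_adv_loss
```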
95 changes: 95 additions & 0 deletions configs/imagenet_vqgan_training_jewels_f16_vqgan.yaml
@@ -0,0 +1,95 @@
wandb:
    entity: null

experiment:
    project: "muse"
    name: "imagenet-vqgan-training"
    output_dir: "imagenet-vqgan-training"
    max_train_examples: 1281167 # total number of imagenet examples
    max_eval_examples: 12800
    save_every: 1000
    eval_every: 1000
    generate_every: 1000
    log_every: 30
    log_grad_norm_every: 500
    resume_from_checkpoint: False
    resume_lr_scheduler: True

model:
    vq_model:
        type: "taming_vqgan"
        pretrained: "vqgan-f16-8192-laion-movq"
    gradient_checkpointing: True
    enable_xformers_memory_efficient_attention: True

dataset:
    type: "classification"
    params:
        train_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-train-{000000..000320}.tar"
        eval_shards_path_or_url: "/p/scratch/ccstdl/muse/muse-imagenet/imagenet-val-{000000..000012}.tar"
        imagenet_class_mapping_path: "/p/scratch/ccstdl/muse/imagenet-class-mapping.json"
        validation_prompts_file: null
        batch_size: ${training.batch_size}
        shuffle_buffer_size: 1000
        num_workers: 4
        resolution: 256
        pin_memory: True
        persistent_workers: True
    preprocessing:
        max_seq_length: 16
        resolution: 256
        center_crop: True
        random_flip: False

discriminator:
    dim: 64
    channels: 3
    groups: 32
    init_kernel_size: 5
    kernel_size: 3
    act: "silu"
    discr_layers: 4

optimizer:
    name: adamw
    params: # default adamw params
        learning_rate: 1e-4
        scale_lr: False # scale learning rate by total batch size
        beta1: 0.9
        beta2: 0.999
        weight_decay: 0.01
        epsilon: 1e-8
        discr_learning_rate: 1e-4

lr_scheduler:
    scheduler: "constant_with_warmup"
    params:
        learning_rate: ${optimizer.params.learning_rate}
        warmup_steps: 1000

training:
    gradient_accumulation_steps: 2
    batch_size: 16
    mixed_precision: "bf16"
    enable_tf32: True
    use_ema: False
    seed: 9345104
    max_train_steps: 200000
    overfit_one_batch: False
    cond_dropout_prob: 0.1
    min_masking_rate: 0.0
    label_smoothing: 0.0
    max_grad_norm: null
    guidance_scale: 2.0
    generation_timesteps: 8
    # related to vae code sampling
    use_soft_code_target: False
    use_stochastic_code: False
    soft_code_temp: 1.0
    timm_discriminator_backend: "vgg19"
    timm_disc_layers: "features|pre_logits|head"
    timm_discr_offset: 0
    vae_loss: "l2"
    num_validation_log: 4
    discriminator_warmup: 10000
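The `timm_discriminator_backend: "vgg19"` and `timm_disc_layers: "features|pre_logits|head"` keys line up with the "Starting adding projected gan tech" commit: a frozen timm backbone supplies intermediate feature maps for the discriminator. A sketch of how such keys could be consumed, assuming timm's VGG exposes `features`, `pre_logits`, and `head` as submodules (which matches the pipe-separated names in the config); the splitting loop is an illustration, not the PR's exact code:

```python
# Sketch: slicing a timm backbone into the stages named by
# timm_disc_layers, projected-GAN style. Assumes timm's vgg19 exposes
# submodules named features / pre_logits / head; illustrative only.
import timm
import torch

backbone = timm.create_model("vgg19", pretrained=True).eval()
for p in backbone.parameters():
    p.requires_grad_(False)  # the feature extractor stays frozen

stage_names = "features|pre_logits|head".split("|")

def extract_fmaps(x: torch.Tensor) -> list[torch.Tensor]:
    """Run x through each named stage, collecting the output of each."""
    fmaps = []
    for name in stage_names:
        x = getattr(backbone, name)(x)
        fmaps.append(x)
    return fmaps

# resolution: 256 in the configs above
fmaps = extract_fmaps(torch.randn(1, 3, 256, 256))
print([tuple(f.shape) for f in fmaps])
```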
83 changes: 83 additions & 0 deletions slurm_scripts/imagenet_vqgan.slurm
@@ -0,0 +1,83 @@
#!/bin/bash
#SBATCH --job-name=vqgan_testing
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1  # crucial - only 1 task per node; torch.distributed.run spawns the per-GPU processes
#SBATCH --cpus-per-task=96
#SBATCH --gres=gpu:8
#SBATCH --exclusive
#SBATCH --partition=g40
#SBATCH --output=/fsx/Isamu/logs/maskgit-imagenet/%x-%j.out

set -x -e

echo "START TIME: $(date)"

MUSE_REPO=/fsx/Isamu/open-muse
OUTPUT_DIR=/fsx/Isamu
LOG_PATH=$OUTPUT_DIR/main_log.txt

mkdir -p $OUTPUT_DIR
touch $LOG_PATH
pushd $MUSE_REPO

CMD=" \
training/train_vqgan.py config=configs/imagenet_vqgan_training.yaml \
wandb.entity=isamu \
experiment.name=$(basename $OUTPUT_DIR) \
experiment.output_dir=$OUTPUT_DIR \
training.seed=9345104 \
training.batch_size=160 \
"

GPUS_PER_NODE=8
NNODES=$SLURM_NNODES

# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000

export LAUNCHER="python -u -m torch.distributed.run \
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
--rdzv_backend c10d \
--max_restarts 0 \
--tee 3 \
"

echo $CMD

# hide duplicated errors using this hack - will be properly fixed in pt-1.12
# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json

# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1

# AWS specific
export NCCL_PROTO=simple
export RDMAV_FORK_SAFE=1
export FI_EFA_FORK_SAFE=1
export FI_EFA_USE_DEVICE_RDMA=1
export FI_PROVIDER=efa
export FI_LOG_LEVEL=1
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_IFNAME=ens


# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS=" \
--wait=60 \
--kill-on-bad-exit=1 \
"

# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD
clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD" 2>&1 | tee $LOG_PATH

echo "END TIME: $(date)"