
experiment tracker agnostic #73

Closed
lucidrains opened this issue May 7, 2022 · 5 comments

@lucidrains (Owner)

Ran into a bunch of issues with wandb and distributed training for https://github.com/lucidrains/dalle-pytorch, so we should refactor any training script to be experiment tracker agnostic this time around.

@lucidrains (Owner, Author) commented May 7, 2022

A good template would be how the logs are accumulated here: https://github.com/lucidrains/tf-bind-transformer/blob/main/tf_bind_transformer/training_utils.py. Additional helper functions can be brought in for "maybe transforms" on certain keys in the log.
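For context, a minimal sketch of that accumulation pattern, plus a hedged "maybe transform" helper (the helper name and signature are illustrative, not taken from the linked file):

# accumulate scalar values from per-step logs into a running log dict
def accum_log(log, new_logs):
    for key, new_value in new_logs.items():
        log[key] = log.get(key, 0.) + new_value
    return log

# hypothetical "maybe transform" helper: apply fn to a key only if it is present
def maybe_transform(log, key, fn):
    if key in log:
        log[key] = fn(log[key])
    return log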

@rom1504 (Collaborator) commented May 9, 2022

What's the overall idea behind being experiment tracker agnostic?
Do you want to support other trackers, or do you mostly want to be able to disable it?

Regarding distributed training, I figure there are two things to support:

  1. Logging only on node 0 (a minimal sketch follows below)
  2. Not logging on node 0 at all, but letting nodes report through some custom channel (e.g. the disk), so some other node (e.g. a login node) can retrieve that information and log to the tracker (this is needed on JUWELS, for example, where compute nodes don't have access to the internet)

How would you want to implement this? What's the main goal?
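For point 1, a minimal sketch of how logging could be gated on rank 0 with torch.distributed (the wrapper name rank_zero_only is hypothetical, and it assumes the process group has already been initialized):

import torch.distributed as dist

def rank_zero_only(log_fn):
    # wrap a log_fn so it becomes a no-op on every rank except 0
    def inner(log):
        if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
            return
        log_fn(log)
    return inner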

@lucidrains (Owner, Author)

@rom1504 both to support other trackers and to be able to disable it. I've done this successfully for some other projects by now; here is an example of what I have for https://github.com/lucidrains/video-diffusion-pytorch

import wandb

# `Trainer` and `diffusion` (a GaussianDiffusion model) are assumed to be set up
# beforehand, e.g. from video_diffusion_pytorch import GaussianDiffusion, Trainer
wandb.init(project = 'video-diffusion')
wandb.run.name = 'resnet'
wandb.run.save()

trainer = Trainer(
    diffusion,
    '/home/phil/dl/nuwa-pytorch/gif-moving-mnist/',
    results_folder = './results-new-focus-present',
    train_batch_size = 4,
    train_lr = 2e-5,
    save_and_sample_every = 1000,
    max_grad_norm = 0.5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 8,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True                        # turn on mixed precision
)

trainer.load(-1)    # -1 loads the latest saved checkpoint from the results folder

def log_fn(log):
    # wrap raw sample arrays so wandb renders them as a video
    if 'sample' in log:
        log['sample'] = wandb.Video(log['sample'])
    wandb.log(log)

trainer.train(log_fn = log_fn, prob_focus_present = 0.)

@lucidrains (Owner, Author)

The log_fn can be made more composable for sure, as you may want to exclude certain keys from being logged, wrap other ones, derive other keys from the available ones in the set, etc.
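A hedged sketch of what that composability could look like (exclude_keys, wrap_key, and compose_log_fn are hypothetical names, not part of any existing API):

# hypothetical composable helpers over a log dict

def exclude_keys(log, keys):
    # drop keys that should not be sent to the tracker
    return {k: v for k, v in log.items() if k not in keys}

def wrap_key(log, key, fn):
    # transform a key only if it is present (e.g. wrap samples as wandb.Video)
    if key in log:
        log[key] = fn(log[key])
    return log

def compose_log_fn(tracker_log, excluded = (), wrappers = ()):
    def log_fn(log):
        log = exclude_keys(log, excluded)
        for key, fn in wrappers:
            log = wrap_key(log, key, fn)
        tracker_log(log)
    return log_fn

# usage, assuming the wandb example above:
# trainer.train(log_fn = compose_log_fn(wandb.log, wrappers = (('sample', wandb.Video),)))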

lucidrains mentioned this issue May 11, 2022
@lucidrains (Owner, Author)

started 89de5af
