Minimize how often we load args.resume #71

Open
achalddave opened this issue Nov 6, 2023 · 1 comment
Labels
enhancement (New feature or request), good first issue (Good for newcomers)

Comments

@achalddave
Collaborator

Currently, we may load the args.resume checkpoint up to 3 times (in load_model, load_optimizer, and load_data_chunks). This can be slow for big models, so we should avoid re-loading it in these spots:

open_lm/open_lm/main.py

Lines 110 to 156 in 97d0a4a

def load_model(args, model):
    checkpoint = pt_load(args.resume, map_location="cpu")
    # default for bare checkpoints; otherwise start_epoch is unbound in the else branch
    start_epoch = 0
    if "epoch" in checkpoint:
        # resuming a train checkpoint w/ epoch and optimizer state
        start_epoch = checkpoint["epoch"]
        sd = checkpoint["state_dict"]
        if next(iter(sd.items()))[0].startswith("module"):
            sd = {k[len("module.") :]: v for k, v in sd.items()}
        model.load_state_dict(sd)
        logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
    else:
        # loading a bare (model only) checkpoint for fine-tune or evaluation
        model.load_state_dict(checkpoint)
        logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
    return start_epoch

def load_optimizer(args, model, optimizer, scaler):
    potential_checkpoint = args.resume.replace("epoch_", "optimizer_")
    if check_exists(potential_checkpoint):
        checkpoint = pt_load(potential_checkpoint, map_location="cpu")
    else:
        checkpoint = pt_load(args.resume, map_location="cpu")
    if "optimizer" in checkpoint:
        if optimizer is not None:
            osd = checkpoint["optimizer"]
            if args.fsdp:
                osd = FSDP.optim_state_dict_to_load(
                    model=model, optim=optimizer, optim_state_dict=osd
                )
            optimizer.load_state_dict(osd)
            logging.info(f"=> resuming optimizer")
        if scaler is not None and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])
    else:
        logging.info(f"=> WARNING: not resuming optimizer.")

def load_data_chunks(args):
    checkpoint = pt_load(args.resume, map_location="cpu")
    if "next_chunk" in checkpoint and "samples_seen" in checkpoint:
        return checkpoint["next_chunk"], checkpoint["samples_seen"]
    else:
        logging.info(
            f"=> WARNING: tried to resume a checkpoint without data chunk info. Assuming next_chunk = 0."
        )
        return 0, 0
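One possible fix, sketched under assumptions: deserialize args.resume once and pass the resulting dict to each consumer instead of having every function call pt_load itself. The helper names below (load_model_from, optimizer_checkpoint, load_data_chunks_from) are hypothetical, not existing open_lm functions. The loader is injected as a callable (in main.py it would presumably wrap pt_load(path, map_location="cpu")), and the FSDP optimizer-state conversion and scaler handling are omitted for brevity.

```python
import logging
from typing import Any, Callable, Dict, Tuple

Checkpoint = Dict[str, Any]


def load_model_from(checkpoint: Checkpoint, load_state_dict: Callable[[dict], None]) -> int:
    """Restore model weights from an already-loaded checkpoint; return the start epoch."""
    start_epoch = 0  # bare (model-only) checkpoints carry no epoch
    if "epoch" in checkpoint:
        start_epoch = checkpoint["epoch"]
        sd = checkpoint["state_dict"]
        # Strip the "module." prefix left by DDP/FSDP wrappers, if present.
        if next(iter(sd)).startswith("module"):
            sd = {k[len("module."):]: v for k, v in sd.items()}
        load_state_dict(sd)
    else:
        load_state_dict(checkpoint)
    return start_epoch


def optimizer_checkpoint(
    path: str,
    checkpoint: Checkpoint,
    load_fn: Callable[[str], Checkpoint],
    exists: Callable[[str], bool],
) -> Checkpoint:
    """Return the dict holding optimizer state, reusing `checkpoint` when possible."""
    opt_path = path.replace("epoch_", "optimizer_")
    if opt_path != path and exists(opt_path):
        return load_fn(opt_path)  # separate optimizer file: this extra load is unavoidable
    return checkpoint  # optimizer state (if any) lives in the already-loaded checkpoint


def load_data_chunks_from(checkpoint: Checkpoint) -> Tuple[int, int]:
    """Return (next_chunk, samples_seen), defaulting to (0, 0) when absent."""
    if "next_chunk" in checkpoint and "samples_seen" in checkpoint:
        return checkpoint["next_chunk"], checkpoint["samples_seen"]
    logging.info("=> WARNING: no data chunk info in checkpoint. Assuming next_chunk = 0.")
    return 0, 0


# Demo with a stand-in loader that records every path it deserializes.
calls = []
store = {
    "ckpt/epoch_3.pt": {
        "epoch": 3,
        "state_dict": {"module.w": 1.0},
        "optimizer": {"lr": 0.1},
        "next_chunk": 7,
        "samples_seen": 700,
    }
}


def fake_load(path: str) -> Checkpoint:
    calls.append(path)
    return store[path]


weights: Dict[str, Any] = {}
path = "ckpt/epoch_3.pt"
checkpoint = fake_load(path)  # the only deserialization of the resume checkpoint
start_epoch = load_model_from(checkpoint, weights.update)
opt_state = optimizer_checkpoint(path, checkpoint, fake_load, lambda p: p in store)["optimizer"]
next_chunk, samples_seen = load_data_chunks_from(checkpoint)
```

In this sketch the resume file is read exactly once; only a separate optimizer_ file, when it exists, costs a second read. The `opt_path != path` guard also covers a resume path without "epoch_" in it: there, the replace in load_optimizer is a no-op, check_exists succeeds on the same file, and the original code would read it again.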

achalddave added the enhancement (New feature or request) and good first issue (Good for newcomers) labels on Nov 6, 2023
@achalddave
Collaborator Author

In the current version of the code, I believe the main offender is: https://github.com/mlfoundations/open_lm/blob/main/open_lm/main.py#L142-L151.
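If threading one checkpoint dict through every call site is awkward, a less invasive alternative is to memoize the loader itself. This is a hedged sketch: make_cached_loader is a hypothetical helper, and slow_load is a stand-in for pt_load, which the snippet above calls as pt_load(path, map_location="cpu").

```python
import functools
from typing import Any, Callable, Dict

Checkpoint = Dict[str, Any]


def make_cached_loader(load_fn: Callable[..., Checkpoint], maxsize: int = 2):
    """Wrap a checkpoint loader so repeated loads of the same path are served from memory."""

    @functools.lru_cache(maxsize=maxsize)
    def _cached(path: str, map_location: str) -> Checkpoint:
        return load_fn(path, map_location=map_location)

    def load(path: str, map_location: str = "cpu") -> Checkpoint:
        # Pass arguments positionally so equivalent calls map to the same cache key.
        return _cached(path, map_location)

    # Expose cache_clear so the caller can drop the (large) checkpoint once resuming is done.
    load.cache_clear = _cached.cache_clear
    return load


# Demo with a stand-in loader that records each real read.
reads = []


def slow_load(path: str, map_location: str = "cpu") -> Checkpoint:
    reads.append(path)
    return {"epoch": 1, "path": path}


load = make_cached_loader(slow_load)
first = load("ckpt/epoch_1.pt", map_location="cpu")
second = load("ckpt/epoch_1.pt")  # served from the cache, no second read
load.cache_clear()
third = load("ckpt/epoch_1.pt")  # cache was cleared, so this reads again
```

The trade-off is that every caller receives the same dict object, so none of them may mutate it in place. The cached checkpoint also stays in CPU memory until cache_clear() is called, so it should be cleared right after resuming.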
