Skip to content

Commit

Permalink
saving checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 27, 2023
1 parent 6146ee7 commit f22c3d2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
17 changes: 14 additions & 3 deletions magvit2_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
random_split_seed = 42,
valid_frac = 0.05,
validate_every_step = 100,
checkpoint_every_step = 100,
num_frames = 17,
accelerate_kwargs: dict = dict(),
ema_kwargs: dict = dict(),
Expand Down Expand Up @@ -113,6 +114,7 @@ def __init__(
self.valid_dataloader = DataLoader(valid_dataset, shuffle = True, drop_last = True, batch_size = batch_size)

self.validate_every_step = validate_every_step
self.checkpoint_every_step = checkpoint_every_step

# optimizers

Expand Down Expand Up @@ -159,6 +161,9 @@ def unwrapped_model(self):
def is_local_main(self):
return self.accelerator.is_local_main_process

def wait(self):
return self.accelerator.wait_for_everyone()

def print(self, msg):
return self.accelerator.print(msg)

Expand Down Expand Up @@ -211,7 +216,7 @@ def train_step(self, dl_iter):
self.optimizer.step()
self.optimizer.zero_grad()

self.accelerator.wait_for_everyone()
self.wait()

# update ema model

Expand Down Expand Up @@ -273,11 +278,17 @@ def train(self):

self.train_step(dl_iter)

self.accelerator.wait_for_everyone()
self.wait()

if self.is_main and not (step % self.validate_every_step):
self.valid_step(valid_dl_iter)

self.accelerator.wait_for_everyone()
self.wait()

if self.is_main and not (step % self.checkpoint_every_step):
checkpoint_num = step // self.checkpoint_every_step
self.save(f'./checkpoint.{checkpoint_num}.pt')

self.wait()

step += 1
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.42'
__version__ = '0.0.43'

0 comments on commit f22c3d2

Please sign in to comment.