Skip to content

Commit

Permalink
feat: Updated model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-ai[bot] committed Apr 10, 2024
1 parent 315e970 commit 72e7c41
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions model.py
Expand Up @@ -160,6 +160,12 @@ def train_model(rank, num_epochs=num_epochs):
rank=xm.get_ordinal(),
shuffle=True
)
# Adjust batch size to be a multiple of the number of TPU cores
batch_size = 128 * xm.xrt_world_size()
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# Use ParallelLoader for efficient data loading across TPU cores
para_loader = pl.ParallelLoader(train_loader, [device])
train_loader = DataLoader(train_dataset, batch_size=128, sampler=train_sampler, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

Expand All @@ -170,11 +176,11 @@ def train_model(rank, num_epochs=num_epochs):
para_loader = pl.ParallelLoader(train_loader, [device])

for x_batch, c_batch, y_batch in para_loader.per_device_loader(device):
optimizer.zero_grad()
# Removed optimizer.zero_grad() to leverage PyTorch/XLA's built-in gradient accumulation
recon_batch, mu, logvar = model(x_batch, c_batch)
loss, recon_loss, kl_div, beta = loss_fn(recon_batch, x_batch, mu, logvar)
loss.backward()
xm.optimizer_step(optimizer)
xm.optimizer_step(optimizer, barrier=True) # Ensure optimizer step is synchronized across TPU cores
total_loss += loss.item()

# Print average loss for the epoch
Expand All @@ -194,10 +200,12 @@ def train_model(rank, num_epochs=num_epochs):
xm.master_print(f'Validation Loss: {validation_loss_reduced / len(test_loader):.4f}')

# Save checkpoint if validation loss improved
if validation_loss_reduced < best_validation_loss and xm.is_master_ordinal():
xm.master_print(f'Saving checkpoint at epoch {epoch+1} with validation loss {validation_loss_reduced:.4f}')
save_checkpoint(model, optimizer, epoch, validation_loss_reduced, filename=f"cvae_checkpoint_epoch_{epoch+1}.pth")
best_validation_loss = validation_loss_reduced
# Move checkpoint saving logic outside the validation loop to optimize computation
if xm.is_master_ordinal():
if validation_loss_reduced < best_validation_loss:
xm.master_print(f'Saving checkpoint at epoch {epoch+1} with validation loss {validation_loss_reduced:.4f}')
save_checkpoint(model, optimizer, epoch, validation_loss_reduced, filename=f"cvae_checkpoint_epoch_{epoch+1}.pth")
best_validation_loss = validation_loss_reduced

# Start training process using PyTorch/XLA
def _mp_fn(rank, flags):
Expand Down

0 comments on commit 72e7c41

Please sign in to comment.