Skip to content

Commit

Permalink
Update model_manager.py - trainer config change - Issue #240
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronphilip19 committed Apr 26, 2024
1 parent 80fe01b commit e34c71f
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions zamba/models/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,19 +280,17 @@ def train_model(

accelerator, devices = configure_accelerator_and_devices_from_gpus(train_config.gpus)


multiprocessing_strategy = getattr(train_config,"multiprocessing_strategy",None)

trainer = pl.Trainer(
accelerator=accelerator,
devices=devices,
max_epochs=train_config.max_epochs,
logger=tensorboard_logger,
callbacks=callbacks,
fast_dev_run=train_config.dry_run,
strategy=(
DDPStrategy(find_unused_parameters=False)
if (data_module.multiprocessing_context is not None) and (train_config.gpus > 1)
else "auto"
),
strategy = multiprocessing_strategy,
)
#Set the strategy within trainer to reflect changes

if video_loader_config.cache_dir is None:
logger.info("No cache dir is specified. Videos will not be cached.")
Expand Down

0 comments on commit e34c71f

Please sign in to comment.