Hi, while trying to resume training, logging doesn't seem to work as expected. When I resume, the `Loop.log` function no longer prints after each iteration. For example, I have the following `train` and `resume` functions:
```elixir
def train(model, train_data, opts \\ []) do
  epochs = Keyword.get(opts, :epochs, 1)
  log_every = Keyword.get(opts, :log_every, 1)
  checkpoint_every = Keyword.get(opts, :checkpoint_every, 100)
  optimizer = build_optimizer(opts)
  loss = build_loss(opts)

  model
  |> Loop.trainer(loss, optimizer)
  |> Loop.log(&log_message/1, event: :iteration_completed, filter: [every: log_every])
  |> Loop.checkpoint(event: :iteration_completed, filter: [every: checkpoint_every], path: @checkpoint_path)
  |> Loop.run(train_data, %{}, epochs: epochs, garbage_collect: true, force_garbage_collection?: true)
end
```
```elixir
def resume(model, train_data, state_path, opts \\ []) do
  Logger.info("Resuming from checkpoint: #{state_path}.")
  epochs = Keyword.get(opts, :epochs, 1)
  log_every = Keyword.get(opts, :log_every, 100)
  checkpoint_every = Keyword.get(opts, :checkpoint_every, 100)
  optimizer = build_optimizer(opts)
  loss = build_loss(opts)
  state = state_path |> File.read!() |> Loop.deserialize_state()

  model
  |> Loop.trainer(loss, optimizer)
  |> Loop.log(&log_message/1, event: :iteration_completed, filter: [every: log_every])
  |> Loop.checkpoint(event: :iteration_completed, filter: [every: checkpoint_every], path: @checkpoint_path)
  |> Loop.from_state(state)
  |> Loop.run(train_data, %{}, epochs: epochs, garbage_collect: true, force_garbage_collection?: true)
end
```

Versions:
- Axon: 0.8.0
- Nx: 0.10.0
- EMLX: 0.2.0
Am I doing something wrong here?