[BUG] scale_lr fails for lr_scaling_method="sqrt" due to torch.sqrt on Python float #7733

@alismil

Description

When using dynamic batching with lr_scaling_method="sqrt", training fails with:

"TypeError: sqrt(): argument 'input' (position 1) must be Tensor, not float"

The error originates from scale_lr() on line 159 in deepspeed/runtime/data_pipeline/data_sampling/variable_batch_size_and_lr.py:

return base_lr * torch.sqrt(batch_size / base_batch_size)

Here, batch_size and base_batch_size are Python integers, so batch_size / base_batch_size evaluates to a Python float. torch.sqrt() accepts only a Tensor as input, so passing it this float raises the TypeError above. Replacing torch.sqrt with math.sqrt on this line should fix the problem.
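A minimal sketch of the proposed fix follows. It assumes a scale_lr(base_batch_size, batch_size, base_lr, method) signature with "linear", "sqrt", and "none" methods; that signature is an assumption modeled on the issue, not a copy of DeepSpeed's actual code. The only change that matters is the "sqrt" branch, which uses math.sqrt on the Python float ratio.

```python
import math


def scale_lr(base_batch_size, batch_size, base_lr=1.0, method="linear"):
    """Scale a reference learning rate to a new batch size.

    Sketch only: the signature and method names are assumptions,
    not DeepSpeed's real implementation.
    """
    ratio = batch_size / base_batch_size  # int / int -> Python float
    if method == "linear":
        # Linear scaling: the learning rate grows proportionally with batch size.
        return base_lr * ratio
    if method == "sqrt":
        # math.sqrt accepts plain Python numbers; torch.sqrt requires a Tensor.
        return base_lr * math.sqrt(ratio)
    if method is None or str(method).lower() == "none":
        # No scaling requested: keep the reference learning rate.
        return base_lr
    raise ValueError(f"Unknown scaling method: {method}")


# A 4x larger batch doubles the learning rate under square-root scaling.
print(scale_lr(32, 128, base_lr=1e-3, method="sqrt"))  # 0.002
```

If the inputs could ever be Tensors rather than Python numbers, an alternative is ratio ** 0.5, which works for both floats and Tensors; math.sqrt is the simpler choice when the batch sizes are always plain integers, as described above.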

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working), training
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests