This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Conversation


@vkuzo vkuzo commented Jan 11, 2024

No description provided.

@facebook-github-bot added the CLA Signed label on Jan 11, 2024
@vkuzo requested a review from drisspg on January 11, 2024, 22:00
@facebook-github-bot

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

README.md Outdated
# User API, subject to change

## single GPU
We provide two types of scaling for per-tensor scaling of tensors: dynamic and delayed.
@drisspg (Contributor) commented Jan 11, 2024:

This sounds kind of weird; what about "We provide two scaling strategies: per-tensor dynamic and delayed."?
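For readers skimming this thread, a rough sketch of what the two per-tensor scaling flavors look like from the user side. The `Float8Linear` / `Float8DynamicLinear` classes appear later in this PR's file list; the `swap_linear_with_float8_linear` helper name, import paths, and signatures are assumptions and may differ by version.

```python
import torch.nn as nn

# Assumed import paths; the module layout may differ between versions.
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_linear import Float8Linear                # delayed scaling
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear  # dynamic scaling

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

# Dynamic scaling: scales are recomputed from the current tensor values every iteration.
swap_linear_with_float8_linear(model, Float8DynamicLinear)

# Delayed scaling (alternative): scales come from a running history of amax values,
# which the training loop has to keep in sync (see the delayed-scaling sketch further down).
# swap_linear_with_float8_linear(model, Float8Linear)
```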

README.md Outdated

# optional: use FSDP. Note that workarounds are needed for autocast+compile+FSDP+float8 to work
from float8_experimental import config
config.enable_amax_init = False
Contributor:

Can we do something like: if using FSDP, do the config stuff; otherwise you don't need to?

README.md Outdated

We are using a module swap UX to keep things simple. If the user model has `torch.nn.Linear` modules or their `fairscale` TP/SP equivalents,
we can convert them to float8. `F.linear`, `torch.mm`, `torch.matmul` are not supported at the moment.
# upcoming work
Contributor:

Maybe we make a tracking issue for upcoming work and link it here, just so we don't forget to remove this from the README. Either way is okay though.
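For context on the module swap UX quoted a few lines up, a generic illustration of what a module swap does; this is not the library's actual `swap_linear_with_float8_linear` implementation, and the `from_float` constructor hook is an assumption.

```python
import torch.nn as nn

def swap_linears(module: nn.Module, float8_linear_cls) -> nn.Module:
    """Recursively replace nn.Linear children with a float8 equivalent (illustrative only)."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            # Assumes the float8 class exposes a from_float() constructor that
            # copies the weight/bias from the original nn.Linear.
            setattr(module, name, float8_linear_cls.from_float(child))
        else:
            swap_linears(child, float8_linear_cls)
    return module
```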


## multi GPU

### TP/SP
Contributor:

Maybe we still keep the TP/SP stuff documented, just saying that we know this doesn't work with compile and that we plan to have dtensor integration for this.

Contributor Author (@vkuzo):

hmm, it's not that usable without compile though :( I'm more excited about just deleting this


### Tensor subclasses

We are using tensor subclasses (`Float8Tensor`) to write modular code which satisfies
Contributor:

I think this is still fine, no?

Contributor Author (@vkuzo):

will copy-paste to design
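To make the `Float8Tensor` discussion concrete, a toy sketch of the wrapper-subclass pattern: the subclass stores the float8 payload plus a scale but advertises the original high-precision dtype, which is what lets it satisfy autograd's `x.dtype == x.grad.dtype` restriction mentioned below. This is illustrative only; the real `Float8Tensor` also implements `__torch_dispatch__`, omitted here.

```python
import torch

class ToyFloat8Tensor(torch.Tensor):
    """Toy wrapper: float8 payload + scale, presented to autograd as the original dtype."""

    @staticmethod
    def __new__(cls, fp8_data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype):
        # Advertise orig_dtype (e.g. bfloat16) so x and x.grad have matching dtypes,
        # even though the stored payload is float8.
        return torch.Tensor._make_wrapper_subclass(
            cls, fp8_data.shape, dtype=orig_dtype, device=fp8_data.device
        )

    def __init__(self, fp8_data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype):
        self._data = fp8_data
        self._scale = scale

    def __repr__(self):
        return f"ToyFloat8Tensor(shape={tuple(self.shape)}, dtype={self.dtype})"
```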

# the rest of the flow is the same as the single GPU flow

# high level technical design
Contributor:

same with this, I don't mind the high level design

Contributor Author (@vkuzo):

I think all of the below is slightly outdated, and it is also more dev docs than user README. I think we can have an issue about the high level design; it would be nice to edit that without the OSS->Meta PR sync. I can copy-paste this there.

* `float8_experimental/float8_linear.py` - `Float8Linear` (main user facing entry point for delayed scaling)
* `float8_experimental/float8_dynamic_linear.py` - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
* `float8_experimental/float8_tensor.py` - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
* `float8_experimental/tp_linear.py` - `Float8ColumnParallelLinear` / `Float8RowParallelLinear` (TP/SP versions of float8 linear)
Contributor:

same with above
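Since `Float8Linear` (the delayed-scaling entry point) appears in the file list above, a hedged sketch of where the amax/scale sync would sit in a training step; `sync_float8_amax_and_scale_history` is the helper name the library exposed around this time, but treat the import path and exact signature as assumptions.

```python
# Assumed import path; may differ by version.
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history

def train_step(model, optimizer, loss_fn, x, y):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    # Delayed scaling: scales are derived from a history of observed amax values,
    # so the history has to be synced across float8 modules once per iteration.
    sync_float8_amax_and_scale_history(model)
    optimizer.step()
    return loss
```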


@facebook-github-bot

@vkuzo merged this pull request in d0af81a.
