This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Added module override, bnb.nn.Embedding #13 #15 #19
TimDettmers committed Nov 29, 2021
1 parent 3cff679 commit 20e1677
Showing 8 changed files with 140 additions and 6 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -43,7 +43,13 @@ Docs:
Features:
- Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer.
- Added AdamW (copy of Adam with weight decay init 1e-2). #10
- Introduced ModuleConfig overrides which can be used seamlessly at module initialization time.
- Added the `bnb.nn.Embedding` layer, which uses 32-bit optimizer states but, unlike `StableEmbedding`, no layer norm. This works well if you need to fine-tune pretrained models that do not have an embedding layer norm. #19

Bug fixes:
- Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13
- Fixed an unsafe use of eval. #8
- Fixed a bug where the StableEmbedding layer 32-bit optimizer override would not work without registering the whole model first (`bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())`). #13 #15

Docs:
- Added instructions on how to solve "\_\_fatbinwrap_" errors.
1 change: 1 addition & 0 deletions README.md
@@ -83,6 +83,7 @@ For upcoming features and changes and full history see [Patch Notes](CHANGELOG.m
## Errors

1. RuntimeError: CUDA error: no kernel image is available for execution on the device. [Solution](errors_and_solutions.md#No-kernel-image-available)
2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_)

## Compile from source

2 changes: 1 addition & 1 deletion bitsandbytes/nn/__init__.py
@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import StableEmbedding
from .modules import StableEmbedding, Embedding
33 changes: 31 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -18,8 +18,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optiona
                 sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
        super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
        self.norm = torch.nn.LayerNorm(embedding_dim)
        GlobalOptimManager.get_instance().register_parameters(self.weight)
        GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32)
        GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
@@ -42,3 +41,33 @@ def forward(self, input: Tensor) -> Tensor:
            self.norm_type, self.scale_grad_by_freq, self.sparse)

        return self.norm(emb)


class Embedding(torch.nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
                 sparse: bool = False, _weight: Optional[Tensor] = None) -> None:
        super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight)
        GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
        to make the layer compatible with PyTorch < 1.9.
        This means that if this changes in a future PyTorch release, this needs to change too,
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
    '''
    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input, self.weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)

        return emb
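
A minimal sketch of how the new `bnb.nn.Embedding` layer is meant to be used, mirroring the test added below (assumes a CUDA device and the 8-bit Adam optimizer; the variable names are illustrative):
```python
import torch
import bitsandbytes as bnb

emb = bnb.nn.Embedding(100, 512).cuda()       # __init__ registers a 32-bit override for emb.weight
adam = bnb.optim.Adam8bit(emb.parameters())

batch = torch.randint(1, 100, size=(4, 32)).cuda()
loss = emb(batch).mean()
loss.backward()
adam.step()                                   # the override is resolved on this first step

# the embedding's optimizer state stays in 32-bit despite the 8-bit optimizer
assert adam.state[emb.weight]['state1'].dtype == torch.float32
```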
31 changes: 28 additions & 3 deletions bitsandbytes/optim/optimizer.py
@@ -26,6 +26,7 @@ def initialize(self):
        self.index2config = {}
        self.optimizer = None
        self.uses_config_override = False
        self.module_weight_config_triple = []

    @classmethod
    def get_instance(cls):
@@ -77,12 +78,16 @@ def override_config(self, parameters, key=None, value=None, key_value_dict=None)
                if id(p) in self.pid2config: self.pid2config[id(p)].update(key_value_dict)
                else: self.pid2config[id(p)] = key_value_dict

    def register_module_override(self, module, param_name, config):
        self.module_weight_config_triple.append((module, param_name, config))



class Optimizer8bit(torch.optim.Optimizer):

    def __init__(self, params, defaults, optim_bits=32):
        super(Optimizer8bit, self).__init__(params, defaults)
        self.checked_if_on_gpu = False
        self.initialized = False
        self.name2qmap = {}

        self.mng = GlobalOptimManager.get_instance()
@@ -172,7 +177,6 @@ def update_group(group, new_group):
        self.__setstate__({'state': state, 'param_groups': param_groups})

    def to_gpu(self):
        self.checked_if_on_gpu = True
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group['params']):
                if p in self.state:
@@ -181,6 +185,23 @@ def to_gpu(self):
                        if isinstance(v, torch.Tensor):
                            self.state[p][k] = v.to(p.device)

    def check_overrides(self):
        for module, attr, config in self.mng.module_weight_config_triple:
            pmodule = getattr(module, attr)
            assert pmodule is not None
            assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.nn.Parameter)
            found = False
            for gindex, group in enumerate(self.param_groups):
                if found: break
                for pindex, p in enumerate(group['params']):
                    if found: break
                    if id(p) == id(pmodule):
                        # found the matching parameter
                        # init override
                        self.mng.pid2config[id(p)] = config
                        self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)]
                        found = True

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
@@ -196,7 +217,11 @@ def step(self, closure=None):

        overflows = []

        if not self.checked_if_on_gpu: self.to_gpu() # needed for fairseq pure fp16 training
        if not self.initialized:
            self.check_overrides()
            self.to_gpu() # needed for fairseq pure fp16 training
            self.initialized = True

        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group['params']):
                if p.grad is None:
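The override is applied lazily: `register_module_override` only records a `(module, param_name, config)` triple, and the first call to `step()` runs `check_overrides()` to match the parameter by `id()` and install its config before `to_gpu()`. A rough sketch of the resulting order of operations (the module and dimensions here are illustrative):
```python
import torch
import bitsandbytes as bnb

linear = torch.nn.Linear(1024, 1024)
# record the override while the module is being constructed (the parameter may still be on the CPU)
bnb.optim.GlobalOptimManager.get_instance().register_module_override(
    linear, 'weight', {'optim_bits': 32})

linear = linear.cuda()
adam = bnb.optim.Adam8bit(linear.parameters())

loss = linear(torch.randn(8, 1024).cuda()).sum()
loss.backward()
adam.step()   # first step: check_overrides() matches linear.weight and applies {'optim_bits': 32}
```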
13 changes: 13 additions & 0 deletions errors_and_solutions.md
@@ -6,3 +6,16 @@ If you are feeling lucky, you can also try to compile the library from source. T


__If you encounter any other error not listed here, please create an issue. This will help resolve your problem and will help out others in the future.


# fatbinwrap

This error occurs when there is a mismatch between the CUDA version the C++ library was compiled against and the CUDA version loaded at runtime. Make sure you have the right CUDA version in your $PATH and $LD_LIBRARY_PATH variables. In the conda base environment you can find the library under:
```bash
ls $CONDA_PREFIX/lib/*cudart*
```
Make sure this path is appended to `LD_LIBRARY_PATH` so bnb can find the CUDA runtime library (cudart).
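For example, assuming the conda environment above is the active one, one way to do this is (adjust the path to wherever cudart actually lives on your system):
```bash
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib
```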

If this does not fix the issue, please try [compilation from source](compile_from_source.md) next.

If this does not work, please open an issue and paste the environment printed when you call `make`, along with the error you see when running bnb.
14 changes: 14 additions & 0 deletions howto_config_override.md
@@ -2,6 +2,7 @@

If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameters while they are still on the CPU, and (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details.

For global overrides in many different places in your code, you can do:
```python
import torch
import bitsandbytes as bnb
Expand All @@ -24,3 +25,16 @@ mng.override_config([model.special.weight, model.also_special.weight],
                    key_value_dict={'is_sparse': True, 'lr': 1e-5, 'betas': (0.9, 0.98)})
```
Possible options for the config override are: `betas, eps, weight_decay, lr, optim_bits, min_8bit_size, percentile_clipping, block_wise, max_unorm`

For overrides of particular layers, we recommend overriding locally in each module. You can do this by passing the module, the name of its parameter attribute, and the override config to the GlobalOptimManager:
```python
class MyModule(torch.nn.Module):
    def __init__(self, din, dout):
        super(MyModule, self).__init__()
        self.linear = torch.nn.Linear(din, dout)
        # optimization will happen in 32-bit and
        # learning rate will be set to 0.0001 independent of the main learning rate
        config = {'optim_bits': 32, 'lr': 0.0001}
        bnb.optim.GlobalOptimManager.get_instance().register_module_override(self.linear, 'weight', config)

```
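
A sketch of how such a module would then be trained (the dimensions and optimizer choice are illustrative); the override is picked up on the first optimizer step:
```python
model = MyModule(1024, 1024).cuda()   # the override is recorded at init time, while the weight is still on the CPU
adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001)
# on the first adam.step(), model.linear.weight gets 32-bit optimizer state with lr=0.0001,
# while the other parameters keep the main optimizer config (lr=0.001)
```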
46 changes: 46 additions & 0 deletions tests/test_modules.py
@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
import bitsandbytes as bnb

from itertools import product

from bitsandbytes import functional as F


@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
def test_embeddings(embcls):
    bnb.optim.GlobalOptimManager.get_instance().initialize()
    emb1 = torch.nn.Embedding(100, 512).cuda()
    emb2 = embcls(100, 512).cuda()

    adam1 = bnb.optim.Adam8bit(emb1.parameters())
    adam2 = bnb.optim.Adam8bit(emb2.parameters())

    batches = torch.randint(1, 100, size=(100, 4, 32)).cuda()

    for i in range(100):
        batch = batches[i]

        embedded1 = emb1(batch)
        embedded2 = emb2(batch)

        l1 = embedded1.mean()
        l2 = embedded2.mean()

        l1.backward()
        l2.backward()

        adam1.step()
        adam2.step()

        adam1.zero_grad()
        adam2.zero_grad()

    assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8
    assert adam2.state[emb2.weight]['state1'].dtype == torch.float32

