The _FabricModule cannot be jitted after #78 #95

Closed
carmocca opened this issue Mar 28, 2024 · 0 comments · Fixed by #98
Labels
bug Something isn't working help wanted Extra attention is needed

@carmocca (Contributor)

🐛 Bug

extensions/thunder/pretrain.py:146: in setup
    main(
extensions/thunder/pretrain.py:233: in main
    fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
extensions/thunder/pretrain.py:253: in fit
    validate(fabric, model, val_dataloader, max_iters=2)  # sanity check
../nightly-env/lib/python3.10/site-packages/torch/utils/_contextlib.py:115: in decorate_context
    return func(*args, **kwargs)
extensions/thunder/pretrain.py:389: in validate
    loss = forward_and_loss(model, input_ids, targets)
../lightning-thunder/thunder/__init__.py:629: in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
../lightning-thunder/thunder/__init__.py:262: in cache_info_wrapper
    res = fn(*args, **kwargs)
../lightning-thunder/thunder/__init__.py:504: in get_computation_and_inputs
    prologue_trc, computation_trc, *maybe_epilogue = interpreter(
../lightning-thunder/thunder/__init__.py:175: in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
../lightning-thunder/thunder/core/jit_ext.py:1430: in thunder_general_jit
    result = jfn(*args, **kwargs)
../lightning-thunder/thunder/core/interpreter.py:6669: in fn_
    raise e
../lightning-thunder/thunder/core/interpreter.py:6632: in fn_2
    return fn(*args, **kwargs)
extensions/thunder/pretrain.py:371: in forward_and_loss
    logits = model(input_ids)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../lightning/src/lightning/fabric/wrappers.py:142: in forward
    with precision.forward_context():
../lightning/src/lightning/fabric/plugins/precision/half.py:54: in forward_context
    return self.tensor_init_context()
../lightning/src/lightning/fabric/plugins/precision/half.py:46: in tensor_init_context
    return _DtypeContextManager(self._desired_input_dtype)
../lightning-thunder/thunder/core/interpreter.py:6031: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

    def __init__(self, dtype: torch.dtype) -> None:
>       self._previous_dtype: torch.dtype = torch.get_default_dtype()
E       NotImplementedError: Trying to call function torch.get_default_dtype, but it is not yet supported. Please file an issue requesting support. To find out which operations are not yet recongnized by `thunder.jit`, please run `examine` as per:
E       
E       from thunder.examine import examine
E       examine(<your thunder.jit callable argument>, ...)

../lightning/src/lightning/fabric/plugins/precision/utils.py:33: NotImplementedError
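For context, the failing frame is the constructor of Lightning's `_DtypeContextManager`, which `forward_context()` returns via `tensor_init_context()` (see the `half.py` frames above). The constructor snapshots the process-wide default dtype with `torch.get_default_dtype()`, and that is the call `thunder.jit` does not yet recognize. Below is a minimal sketch of such a context manager. The constructor mirrors the one shown in the traceback; the class name `DtypeContextManager` is hypothetical, and the `__enter__`/`__exit__` bodies are an assumption about how the saved dtype is used, not a copy of Lightning's source. It uses `torch.float64` because `torch.set_default_dtype` officially supports `float32`/`float64`; with `precision="16-true"` the target dtype would presumably be `torch.float16`.

```python
import torch


class DtypeContextManager:
    """Hypothetical stand-in for Lightning's `_DtypeContextManager` (sketch only)."""

    def __init__(self, dtype: torch.dtype) -> None:
        # Snapshot the global default dtype; this is the call thunder.jit rejects.
        self._previous_dtype: torch.dtype = torch.get_default_dtype()
        self._new_dtype = dtype

    def __enter__(self) -> None:
        # Assumption: the requested dtype becomes the default while the context is active.
        torch.set_default_dtype(self._new_dtype)

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Restore whatever default was active when the manager was created.
        torch.set_default_dtype(self._previous_dtype)


with DtypeContextManager(torch.float64):
    inside = torch.get_default_dtype()
    t = torch.zeros(1)  # created with the temporary default dtype
outside = torch.get_default_dtype()
print(inside, t.dtype, outside)
```

The default dtype is global interpreter state rather than a tensor operation, which is presumably why the thunder interpreter needs explicit support for `torch.get_default_dtype` before it can trace through this context manager.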

Jitting the `_FabricModule` is currently necessary to compile the forward pass and the loss computation jointly.
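In the traceback, the jitted callable is `forward_and_loss` from `extensions/thunder/pretrain.py`, which calls `model(input_ids)`, so the model has to be traceable as part of it. A minimal sketch of such a function, assuming plain `torch.nn.functional.cross_entropy` as the loss (the real script may use a different loss helper) and a hypothetical toy model in place of the real one; it runs in eager PyTorch here:

```python
import torch
import torch.nn.functional as F


def forward_and_loss(model: torch.nn.Module, input_ids: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # When this whole function is jitted, the model call is traced together with the loss.
    logits = model(input_ids)
    # Flatten (batch, seq, vocab) -> (batch * seq, vocab); the loss choice is an assumption.
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))


# Hypothetical toy language model: token embedding followed by a projection to the vocabulary.
vocab_size, batch, seq = 8, 2, 4
model = torch.nn.Sequential(torch.nn.Embedding(vocab_size, 16), torch.nn.Linear(16, vocab_size))
input_ids = torch.randint(0, vocab_size, (batch, seq))
targets = torch.randint(0, vocab_size, (batch, seq))
loss = forward_and_loss(model, input_ids, targets)
print(loss.shape, loss.item())
```

In the issue's setting, `model` is the `_FabricModule` returned by `fabric.setup(...)` and `forward_and_loss` is wrapped with `thunder.jit`; tracing then enters `_FabricModule.forward` and its precision context, which is where the error above is raised.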

To Reproduce

from lightning import Fabric
import torch
import thunder

fabric = Fabric(devices=1, precision="16-true")
model = torch.nn.Linear(1, 1, bias=False, device=fabric.device)
x = torch.randn(1, 1)
x = fabric.to_device(x)

fmodel = fabric.setup(model)
tmodel = thunder.jit(fmodel)

print(tmodel(x))  # raises NotImplementedError: torch.get_default_dtype is not yet supported by thunder.jit

cc @nikitaved
