22 changes: 11 additions & 11 deletions bitsandbytes/autograd/_functions.py
@@ -2,7 +2,7 @@
 import warnings
 from dataclasses import dataclass
 from functools import reduce # Required in Python 3
-from typing import Tuple, Optional, List
+from typing import Tuple, Optional, Callable
 from warnings import warn

 import torch
@@ -14,9 +14,6 @@
 def prod(iterable):
     return reduce(operator.mul, iterable, 1)

-tensor = torch.Tensor
-
-
 # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
 # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py

@@ -56,7 +53,10 @@ def get_current_outlier_idx(self):
         return torch.Tensor(list(self.outliers)).to(torch.int64)


-def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
+def get_inverse_transform_indices(
+    transform_tile: Callable[[torch.Tensor], torch.Tensor],
+    tile_size: Tuple[int, int],
+):
     """
     Compute a permutation of indices that invert the specified (tiled) matrix transformation

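Note on the retyped transform_tile parameter: the function's job is to trace each flat index through the tile transform, yielding a permutation that can later be scattered through to undo the layout. A minimal sketch of the idea, assuming a dtype-agnostic permutation transform (inverse_indices_sketch is a hypothetical helper; the real implementation encodes indices byte by byte so they fit the int8 tiles the transform actually accepts):

import torch
from typing import Callable, Tuple

def inverse_indices_sketch(
    transform_tile: Callable[[torch.Tensor], torch.Tensor],
    tile_size: Tuple[int, int],
) -> torch.Tensor:
    # Feed a tile whose entries are their own flat indices and apply the
    # transform: the output records which original index landed where.
    d1, d2 = tile_size
    ids = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
    return transform_tile(ids).reshape(-1)

If y = transform_tile(x), then x.view(-1)[indices] equals y.view(-1), so scattering y back through these indices recovers x.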
@@ -496,7 +496,7 @@ class MatMul4Bit(torch.autograd.Function):
     # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, quant_state: F.QuantState = None):
+    def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
         # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
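The Optional change is more than cosmetic: under PEP 484 a None default does not implicitly make a parameter Optional, and strict checkers reject the old spelling (mypy has enabled no_implicit_optional by default since 0.990). A minimal illustration, with hypothetical function names:

from typing import Optional

def old_style(x: int = None): ...            # flagged by strict type checkers
def new_style(x: Optional[int] = None): ...  # explicit, accepted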
@@ -549,10 +549,10 @@ def backward(ctx, grad_output):


 def matmul(
-    A: tensor,
-    B: tensor,
-    out: tensor = None,
-    state: MatmulLtState = None,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    out: Optional[torch.Tensor] = None,
+    state: Optional[MatmulLtState] = None,
     threshold=0.0,
     bias=None
 ):
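With the tensor alias removed, the public signature now reads directly in terms of torch.Tensor. A hedged usage sketch of this 8-bit entry point, assuming a CUDA device, fp16 inputs, and the F.linear convention that B has shape (out_features, in_features); with the default MatmulLtState, B is quantized on the fly:

import torch
import bitsandbytes as bnb

A = torch.randn(4, 64, dtype=torch.float16, device="cuda", requires_grad=True)
B = torch.randn(32, 64, dtype=torch.float16, device="cuda")
# threshold > 0 enables the LLM.int8() mixed-precision decomposition:
# outlier features stay in fp16 while the rest runs through int8 kernels.
out = bnb.matmul(A, B, threshold=6.0)  # expected shape: (4, 32)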
@@ -562,7 +562,7 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None):
+def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
         if A.shape[-1] % quant_state.blocksize != 0:
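For the 4-bit entry point, a hedged usage sketch mirroring what bnb.nn.Linear4bit does in its forward pass, assuming a CUDA device and the F.quantize_4bit API:

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

W = torch.randn(128, 64, dtype=torch.float16, device="cuda")
W4, state = F.quantize_4bit(W, quant_type="nf4")   # packed weights + QuantState
x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
out = bnb.matmul_4bit(x, W4.t(), quant_state=state)  # expected shape: (4, 128)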
4 changes: 2 additions & 2 deletions bitsandbytes/cuda_setup/main.py
@@ -34,9 +34,9 @@
 # not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt
 system = platform.system()
 if system == 'Windows':
-    CUDA_RUNTIME_LIBS: list = ["nvcuda.dll"]
+    CUDA_RUNTIME_LIBS = ["nvcuda.dll"]
 else: # Linux or other
-    CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2']
+    CUDA_RUNTIME_LIBS = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0', 'libcudart.so.12.1', 'libcudart.so.12.2']

 # this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths
 backup_paths = []
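The dropped ': list' annotations were redundant, since the literal assignments already fix the type. For context, these names are the candidates probed when locating the CUDA runtime; a hypothetical sketch of that kind of search (find_cuda_runtime is illustrative, not the bitsandbytes API):

import ctypes
import os

def find_cuda_runtime(search_dirs, candidates):
    # Return the first candidate library that both exists and loads.
    for d in search_dirs:
        for name in candidates:
            path = os.path.join(d, name)
            if os.path.isfile(path):
                try:
                    ctypes.CDLL(path)
                    return path
                except OSError:
                    continue  # present but unloadable; keep looking
    return None

# e.g. find_cuda_runtime(os.environ.get("LD_LIBRARY_PATH", "").split(":"), CUDA_RUNTIME_LIBS)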