diff --git a/torchft/multiprocessing.py b/torchft/multiprocessing.py
index 71b6aa3..7ead1e9 100644
--- a/torchft/multiprocessing.py
+++ b/torchft/multiprocessing.py
@@ -8,7 +8,7 @@
 
 
 class _MonitoredPipe:
-    def __init__(self, pipe: "Connection[object, object]") -> None:
+    def __init__(self, pipe: "Connection[object, object]") -> None:  # type: ignore
         self._pipe = pipe
 
     def send(self, obj: object) -> None:
diff --git a/torchft/multiprocessing_test.py b/torchft/multiprocessing_test.py
index 5f1ebfa..37dcf96 100644
--- a/torchft/multiprocessing_test.py
+++ b/torchft/multiprocessing_test.py
@@ -6,11 +6,11 @@
 from torchft.multiprocessing import _MonitoredPipe
 
 
-def pipe_get(q: "Connection[object, object]") -> None:
+def pipe_get(q: "Connection[object, object]") -> None:  # type: ignore
     q.recv()
 
 
-def pipe_put(q: "Connection[object, object]") -> None:
+def pipe_put(q: "Connection[object, object]") -> None:  # type: ignore
     q.recv()
     q.send(1)
 
diff --git a/torchft/process_group.py b/torchft/process_group.py
index 3313156..74e4419 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -1443,8 +1443,8 @@ def _worker(
     store_addr: str,
     rank: int,
     world_size: int,
-    req_pipe: "Connection[object, object]",
-    future_pipe: "Connection[object, object]",
+    req_pipe: "Connection[object, object]",  # type: ignore
+    future_pipe: "Connection[object, object]",  # type: ignore
     curr_device: int,
 ) -> None:
     try:
diff --git a/torchft/quantization.py b/torchft/quantization.py
index e257ba2..2bcdf4e 100644
--- a/torchft/quantization.py
+++ b/torchft/quantization.py
@@ -125,7 +125,7 @@ def _fused_kernel_quantize_into_fp8(
     # be written
     o_curr_ptr = o_ptr + o_offset
     o_scale_ptr = o_curr_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
-    o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE))
+    o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE))  # type: ignore
 
     # Compute maximum for the current row block by block
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -233,7 +233,7 @@ def _fused_kernel_dequantize_from_fp8(
     # written
     o_curr_ptr = o_ptr + o_offset
     o_scale_ptr = o_curr_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
-    o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE))
+    o_quant_ptr = (o_curr_ptr + SCALE_TL_DTYPE_BYTES).to(tl.pointer_type(TL_FP8_TYPE))  # type: ignore
 
     # Load row scale
     i_row_scale = tl.load(o_scale_ptr)
@@ -342,7 +342,7 @@ def _fused_kernel_reduce_fp8(
     o_rank_row_ptr = o_ptr + all_reduce_rank * o_size_bytes_per_rank + o_offset
     o_rank_scale_ptr = o_rank_row_ptr.to(tl.pointer_type(SCALE_TL_DTYPE))
     o_rank_quant_ptr = (o_rank_row_ptr + SCALE_TL_DTYPE_BYTES).to(
-        tl.pointer_type(TL_FP8_TYPE)
+        tl.pointer_type(TL_FP8_TYPE)  # type: ignore
     )
 
     col_offsets = tl.arange(0, BLOCK_SIZE)
@@ -411,7 +411,7 @@ def _fused_kernel_accumulate_block(
     # Load row scale and block of quantized row
     o_scale_ptr = o_row_ptr.to(tl.pointer_type(tl.float32))
     o_quant_ptr = (o_row_ptr + SCALE_TL_DTYPE_BYTES).to(
-        tl.pointer_type(TL_FP8_TYPE)
+        tl.pointer_type(TL_FP8_TYPE)  # type: ignore
     )
 
     o_row_scale = tl.load(o_scale_ptr)
@@ -580,7 +580,7 @@ def fused_quantize_into_fp8(
         output,
         output_size // all_reduce_group_size,
         all_reduce_group_size,
-        BLOCK_SIZE=BLOCK_SIZE_T,
+        BLOCK_SIZE=BLOCK_SIZE_T,  # type: ignore
         TL_FP8_TYPE=_get_fp8_type(),
         TL_FP8_MAX=_get_fp8_max(),
     )
@@ -630,7 +630,7 @@ def fused_dequantize_from_fp8(
         output,
         output_size // all_reduce_group_size,
         all_reduce_group_size,
-        BLOCK_SIZE=BLOCK_SIZE_T,
+        BLOCK_SIZE=BLOCK_SIZE_T,  # type: ignore
         TL_FP8_TYPE=_get_fp8_type(),
     )
 
@@ -680,7 +680,7 @@ def fused_reduce_fp8(
         all_reduce_group_size,
         all_reduce_rank,
         1.0 if reduce_op == ReduceOp.SUM else float(all_reduce_group_size),
-        BLOCK_SIZE=BLOCK_SIZE_T,
+        BLOCK_SIZE=BLOCK_SIZE_T,  # type: ignore
         TL_FP8_TYPE=_get_fp8_type(),
         TL_FP8_MAX=_get_fp8_max(),
     )
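
Note on the pattern: the quoted "Connection[object, object]" annotations are never evaluated at runtime, so subscripting Connection is safe there, but not every type checker (or stub version) models multiprocessing.connection.Connection as generic in its send/receive types; the trailing `# type: ignore` keeps the richer annotation while silencing that error. The ignores in quantization.py serve the same purpose for Triton constructs (tl.pointer_type results and constexpr kernel parameters) that static checkers do not model. A minimal, self-contained sketch of the Connection pattern, assuming such a checker; pipe_echo is a hypothetical helper, not part of this patch:

# Illustrative sketch only: pipe_echo is hypothetical and not in torchft.
from multiprocessing import Pipe
from multiprocessing.connection import Connection


def pipe_echo(pipe: "Connection[object, object]") -> None:  # type: ignore
    # The string annotation is not evaluated at runtime; the ignore comment
    # covers checkers whose stubs treat Connection as non-generic.
    # Receive one object and send it back over the same duplex pipe.
    pipe.send(pipe.recv())


if __name__ == "__main__":
    local, remote = Pipe()  # duplex pair of Connection objects
    local.send("ping")
    pipe_echo(remote)
    assert local.recv() == "ping"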