Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 64 additions & 23 deletions captum/_utils/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,64 @@
# pyre-strict

import typing
from types import TracebackType
from typing import (
Any,
Callable,
cast,
Iterable,
Iterator,
Literal,
Optional,
Protocol,
runtime_checkable,
TextIO,
Type,
TypeVar,
Union,
)

from tqdm.auto import tqdm
from typing_extensions import Self

T = TypeVar("T")
IterableType = TypeVar("IterableType")
IterableType = TypeVar("IterableType", covariant=True)


@runtime_checkable
class BaseProgress(Protocol):
"""
Protocol defining the base progress bar interfaced with
context manager support.
Note: This protocol is based on the tqdm type stubs.
"""

def __enter__(self) -> Self: ...

def __exit__(
self,
exc_type: object,
exc_value: object,
exc_traceback: object,
) -> None: ...

def close(self) -> None: ...


@runtime_checkable
class IterableProgress(BaseProgress, Iterable[IterableType], Protocol[IterableType]):
"""Protocol for progress bars that support iteration.

Note: This protocol is based on the tqdm type stubs.
"""

...


@runtime_checkable
class Progress(BaseProgress, Protocol):
"""Protocol for progress bars that support manual updates.
Note: This protocol is based on the tqdm type stubs.
"""

# This is a weird definition of Progress, but it's what tqdm does.
def update(self, n: float | None = 1) -> bool | None: ...


class DisableErrorIOWrapper(object):
Expand Down Expand Up @@ -56,7 +95,7 @@ def flush(self, *args: object, **kwargs: object) -> None:
return self._wrapped_run(self._wrapped.flush, *args, **kwargs)


class NullProgress(Iterable[IterableType]):
class NullProgress(IterableProgress[IterableType], Progress):
"""Passthrough class that implements the progress API.

This class implements the tqdm and SimpleProgressBar api but
Expand All @@ -74,25 +113,27 @@ def __init__(
del args, kwargs
self.iterable = iterable

def __enter__(self) -> "NullProgress[IterableType]":
def __iter__(self) -> Iterator[IterableType]:
iterable = self.iterable
if not iterable:
yield from ()
return
for it in iterable:
yield it

def __enter__(self) -> Self:
return self

def __exit__(
self,
exc_type: Union[Type[BaseException], None],
exc_value: Union[BaseException, None],
exc_traceback: Union[TracebackType, None],
) -> Literal[False]:
return False

def __iter__(self) -> Iterator[IterableType]:
if not self.iterable:
return
for it in cast(Iterable[IterableType], self.iterable):
yield it
exc_type: object,
exc_value: object,
exc_traceback: object,
) -> None:
self.close()

def update(self, amount: int = 1) -> None:
pass
def update(self, n: float | None = 1) -> bool | None:
return None

def close(self) -> None:
pass
Expand All @@ -106,7 +147,7 @@ def progress(
file: Optional[TextIO] = None,
mininterval: float = 0.5,
**kwargs: object,
) -> tqdm: ...
) -> Progress: ...


@typing.overload
Expand All @@ -117,7 +158,7 @@ def progress(
file: Optional[TextIO] = None,
mininterval: float = 0.5,
**kwargs: object,
) -> tqdm: ...
) -> IterableProgress[IterableType]: ...


def progress(
Expand All @@ -127,7 +168,7 @@ def progress(
file: Optional[TextIO] = None,
mininterval: float = 0.5,
**kwargs: object,
) -> tqdm:
) -> Union[Progress, IterableProgress[IterableType]]:
return tqdm(
iterable,
desc=desc,
Expand Down
34 changes: 2 additions & 32 deletions captum/attr/_core/feature_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,7 @@

import logging
import math
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Optional,
Protocol,
Tuple,
TypeVar,
Union,
)
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union

import torch
from captum._utils.common import (
Expand All @@ -31,7 +19,7 @@
_run_forward,
)
from captum._utils.exceptions import FeatureAblationFutureError
from captum._utils.progress import progress
from captum._utils.progress import NullProgress, progress, Progress
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
from captum.attr._utils.attribution import PerturbationAttribution
from captum.attr._utils.common import (
Expand All @@ -43,27 +31,9 @@
from torch.futures import collect_all, Future


IterableType = TypeVar("IterableType")

logger: logging.Logger = logging.getLogger(__name__)


class Progress(Protocol):
def update(self, n: int = 1) -> Optional[bool]:
"""TQDM Update method signature."""

def close(self) -> None:
"""TQDM Close method signature."""


class NullProgress:
def update(self, n: int = 1) -> Optional[bool]:
return None

def close(self) -> None:
return None


def _parse_forward_out(forward_output: object) -> Tensor:
"""
A temp wrapper for global _run_forward util to force forward output
Expand Down
Loading