Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 58 additions & 58 deletions invokeai/app/services/invocation_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,36 @@
GIG = 1073741824


@dataclass
class NodeStats:
    """Execution statistics accumulated for one invocation node type."""

    # number of times a node of this type has been executed
    calls: int = 0
    # total wall-clock time spent executing nodes of this type, in seconds
    time_used: float = 0.0
    # peak VRAM observed during execution, in GB
    max_vram: float = 0.0
    # model-cache hit/miss counters recorded for this node type
    cache_hits: int = 0
    cache_misses: int = 0
    # largest cache occupancy observed so far
    cache_high_watermark: int = 0


@dataclass
class NodeLog:
    """Per-graph log mapping each node type to its accumulated statistics."""

    # node_type -> NodeStats; default_factory gives each instance its own dict
    nodes: Dict[str, NodeStats] = field(default_factory=dict)


class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics"

graph_execution_manager: ItemStorageABC["GraphExecutionState"]
# {graph_id => NodeLog}
_stats: Dict[str, NodeLog]
_cache_stats: Dict[str, CacheStats]
ram_used: float
ram_changed: float

@abstractmethod
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
"""
Expand Down Expand Up @@ -94,8 +121,6 @@ def update_invocation_stats(
invocation_type: str,
time_used: float,
vram_used: float,
ram_used: float,
ram_changed: float,
):
"""
Add timing information on execution of a node. Usually
Expand All @@ -104,8 +129,6 @@ def update_invocation_stats(
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
"""
pass

Expand All @@ -116,25 +139,19 @@ def log_stats(self):
"""
pass

@abstractmethod
def update_mem_stats(
self,
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.

@dataclass
class NodeStats:
"""Class for tracking execution stats of an invocation node"""

calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
cache_hits: int = 0
cache_misses: int = 0
cache_high_watermark: int = 0


@dataclass
class NodeLog:
"""Class for tracking node usage"""

# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
pass


class InvocationStatsService(InvocationStatsServiceBase):
Expand All @@ -152,12 +169,12 @@ def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"
class StatsContext:
"""Context manager for collecting statistics."""

invocation: BaseInvocation = None
collector: "InvocationStatsServiceBase" = None
graph_id: str = None
start_time: int = 0
ram_used: int = 0
model_manager: ModelManagerService = None
invocation: BaseInvocation
collector: "InvocationStatsServiceBase"
graph_id: str
start_time: float
ram_used: int
model_manager: ModelManagerService

def __init__(
self,
Expand All @@ -170,7 +187,7 @@ def __init__(
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0
self.start_time = 0.0
self.ram_used = 0
self.model_manager = model_manager

Expand All @@ -191,7 +208,7 @@ def __exit__(self, *args):
)
self.collector.update_invocation_stats(
graph_id=self.graph_id,
invocation_type=self.invocation.type,
invocation_type=self.invocation.type, # type: ignore - `type` is not on the `BaseInvocation` model, but *is* on all invocations
time_used=time.time() - self.start_time,
vram_used=torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0,
)
Expand All @@ -202,11 +219,6 @@ def collect_stats(
graph_execution_state_id: str,
model_manager: ModelManagerService,
) -> StatsContext:
"""
Return a context object that will capture the statistics.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state: GraphExecutionState object from the current session.
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
self._cache_stats[graph_execution_state_id] = CacheStats()
Expand All @@ -217,7 +229,6 @@ def reset_all_stats(self):
self._stats = {}

def reset_stats(self, graph_execution_id: str):
"""Zero the statistics for the indicated graph."""
try:
self._stats.pop(graph_execution_id)
except KeyError:
Expand All @@ -228,12 +239,6 @@ def update_mem_stats(
ram_used: float,
ram_changed: float,
):
"""
Update the collector with RAM memory usage info.

:param ram_used: How much RAM is currently in use.
:param ram_changed: How much RAM changed since last generation.
"""
self.ram_used = ram_used
self.ram_changed = ram_changed

Expand All @@ -244,16 +249,6 @@ def update_invocation_stats(
time_used: float,
vram_used: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
:param ram_used: Current RAM available (GB)
:param ram_changed: Change in RAM usage over course of the run (GB)
"""
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
stats = self._stats[graph_id].nodes[invocation_type]
Expand All @@ -262,14 +257,15 @@ def update_invocation_stats(
stats.max_vram = max(stats.max_vram, vram_used)

def log_stats(self):
"""
Send the statistics to the system logger at the info level.
Stats will only be printed when the execution of the graph
is complete.
"""
completed = set()
errored = set()
for graph_id, node_log in self._stats.items():
current_graph_state = self.graph_execution_manager.get(graph_id)
try:
current_graph_state = self.graph_execution_manager.get(graph_id)
except Exception:
errored.add(graph_id)
continue

if not current_graph_state.is_complete():
continue

Expand Down Expand Up @@ -302,3 +298,7 @@ def log_stats(self):
for graph_id in completed:
del self._stats[graph_id]
del self._cache_stats[graph_id]

for graph_id in errored:
del self._stats[graph_id]
del self._cache_stats[graph_id]