Skip to content

Commit

Permalink
[minor] ThreadLocal to ThreadLocalCheckpointState dataclass (#1007)
Browse files Browse the repository at this point in the history
* ThreadLocal to ThreadLocalCheckpointState dataclass

* remove notes
  • Loading branch information
crutcher committed Jun 19, 2022
1 parent 32b0b98 commit 9c195fe
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions fairscale/nn/checkpoint/checkpoint_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

from contextlib import contextmanager
from dataclasses import dataclass
import functools
import threading
from typing import Any, Dict, Generator, Optional, Tuple
Expand All @@ -21,14 +22,14 @@

# https://docs.python.org/3/library/threading.html#thread-local-data
# Manage the checkpoint context with thread-local data.
class ThreadLocal(threading.local):
def __init__(self) -> None:
self.is_checkpointing = False
self.is_recomputing = False
self.is_checkpointing_disabled = False
@dataclass
class ThreadLocalCheckpointingState(threading.local):
is_checkpointing: bool = False
is_recomputing: bool = False
is_checkpointing_disabled: bool = False


thread_local = ThreadLocal()
thread_local = ThreadLocalCheckpointingState()


@contextmanager
Expand Down

0 comments on commit 9c195fe

Please sign in to comment.