In [1]:
import torch

### Locally disabling gradient computation

The context managers `torch.no_grad()`, `torch.enable_grad()`, and `torch.set_grad_enabled()` are helpful for locally disabling and enabling gradient computation.

These context managers are thread local: they have no effect on work sent to another thread (for example, via the `threading` module).
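
A minimal sketch of what thread locality means in practice. The worker thread starts with the default grad mode (enabled), even though the main thread is inside `torch.no_grad()` when it launches the worker. The `worker` function and `results` dict are illustrative names, not part of PyTorch.

```python
import threading
import torch

x = torch.ones(1, requires_grad=True)
results = {}

def worker():
    # Runs in a fresh thread, which starts with the default grad mode (enabled)
    results["worker_grad_enabled"] = torch.is_grad_enabled()
    results["worker_requires_grad"] = (x * 2).requires_grad

with torch.no_grad():
    # no_grad only affects the thread that entered it (here, the main thread)
    results["main_grad_enabled"] = torch.is_grad_enabled()
    t = threading.Thread(target=worker)
    t.start()
    t.join()

print(results)
```

If work must run without gradients in another thread, that thread has to enter the context manager itself.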

In [3]:
x = torch.zeros(1, requires_grad=True)

with torch.no_grad():
    y = x * 2

y.requires_grad

False

In [4]:
is_train = False
with torch.set_grad_enabled(is_train):
    y = x * 2

y.requires_grad

False

In [6]:
torch.set_grad_enabled(True)
y = x * 2
y.requires_grad

True

In [7]:
torch.set_grad_enabled(False)
y = x * 2
y.requires_grad

False

## Context managers and helper functions

1. [torch.no_grad](#torch.no_grad)
2. [torch.enable_grad](#torch.enable_grad)
3. [torch.set_grad_enabled](#torch.set_grad_enabled)
4. [torch.is_grad_enabled](#torch.is_grad_enabled)
5. [torch.inference_mode](#torch.inference_mode)
6. [torch.is_inference_mode_enabled](#torch.is_inference_mode_enabled)

<a id="torch.no_grad"></a>
### 1. torch.no_grad

Context-manager that disables gradient calculation.

Disabling gradient calculation is useful for **inference**, when you are sure that you will not call `Tensor.backward()`. It will reduce memory consumption for computations that would otherwise have `requires_grad=True`.
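
The memory saving comes from not building the autograd graph. A small sketch: in grad mode the result carries a `grad_fn` node, and the `sin` node keeps its input (the intermediate `x * 2`) alive for `backward()`. Under `no_grad` no node is recorded and nothing is saved.

```python
import torch

x = torch.ones(3, requires_grad=True)

with torch.enable_grad():       # make sure grad mode is on for the comparison
    y = (x * 2).sin()           # recorded: the sin node saves its input for backward
with torch.no_grad():
    z = (x * 2).sin()           # not recorded: no grad_fn, nothing saved for backward

print(y.grad_fn is not None, z.grad_fn is None)  # True True
```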

In this mode, the result of every computation will have `requires_grad=False`, even when the inputs have `requires_grad=True`. There is one exception: factory functions that create a new tensor and take a `requires_grad` keyword argument are NOT affected by this mode.
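
A short sketch of the exception using two ordinary factory functions, `torch.zeros` and `torch.tensor`. An explicit `requires_grad=True` is honoured inside the block, while computations on the resulting tensors still follow the no-grad rule.

```python
import torch

with torch.no_grad():
    # Factory functions honour an explicit requires_grad=True even under no_grad
    w = torch.zeros(3, requires_grad=True)
    b = torch.tensor([1.0, 2.0], requires_grad=True)
    # Ordinary computations inside the block still produce requires_grad=False
    v = w + 1

print(w.requires_grad, b.requires_grad, v.requires_grad)  # True True False
```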

This context manager is thread local; it will not affect computation in other threads.

It also works as a decorator.

In [10]:
x = torch.tensor([1.], requires_grad=True)
with torch.no_grad():
    y = x * 2
y.requires_grad

False

In [11]:
@torch.no_grad()
def doubler(x):
    return x * 2
z = doubler(x)
z.requires_grad

False

In [12]:
@torch.no_grad()
def tripler(x):
    return x * 3
z = tripler(x)
z.requires_grad

False

In [13]:
# factory function exception
with torch.no_grad():
    a = torch.nn.Parameter(torch.randn(10))
a.requires_grad

True

<a id="torch.enable_grad"></a>

### 2. torch.enable_grad

`torch.enable_grad(orig_func=None)`

Context-manager that enables gradient calculation.

Enables gradient calculation if it has been disabled via [`no_grad()`](#torch.no_grad) or [`set_grad_enabled`](#torch.set_grad_enabled).
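
Like the other context managers, `enable_grad` only changes the mode for the duration of its block. A minimal sketch showing that the enclosing `no_grad` state is restored on exit:

```python
import torch

with torch.no_grad():
    before = torch.is_grad_enabled()      # False: disabled by no_grad
    with torch.enable_grad():
        inside = torch.is_grad_enabled()  # True: re-enabled locally
    after = torch.is_grad_enabled()       # False: previous state restored on exit

print(before, inside, after)  # False True False
```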

This context manager is thread local; it will not affect computation in other threads.

It also works as a decorator.

In [14]:
x = torch.tensor([1.], requires_grad=True)
with torch.no_grad():
    with torch.enable_grad():
        y = x * 2
y.requires_grad

True

In [15]:
y.backward()
x.grad

tensor([2.])

In [16]:
@torch.enable_grad()
def doubler(x):
    return x * 2
with torch.no_grad():
    z = doubler(x)
z.requires_grad

True

In [17]:
@torch.enable_grad
def tripler(x):
    return x * 3
with torch.no_grad():
    z = tripler(x)
z.requires_grad

True

<a id="torch.set_grad_enabled"></a>

### 3. torch.set_grad_enabled

`torch.set_grad_enabled(mode)`

Context-manager that sets gradient calculation on or off.

`set_grad_enabled` will enable or disable grads based on its argument `mode`. It can be used as a context-manager or as a function.
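
Because `mode` is an ordinary boolean, one code path can serve both training and evaluation. A minimal sketch, where `model`, `inputs`, and `run_step` are hypothetical names introduced only for illustration:

```python
import torch

model = torch.nn.Linear(4, 2)   # hypothetical toy model
inputs = torch.randn(8, 4)      # hypothetical batch of inputs

def run_step(is_train: bool) -> torch.Tensor:
    # Gradients are tracked only when is_train is True
    with torch.set_grad_enabled(is_train):
        return model(inputs)

print(run_step(True).requires_grad, run_step(False).requires_grad)  # True False
```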

This context manager is thread local; it will not affect computation in other threads.

In [18]:
x = torch.tensor([1.], requires_grad=True)
is_train = False
with torch.set_grad_enabled(is_train):
    y = x * 2
y.requires_grad

False

In [20]:
# used as a function
_ = torch.set_grad_enabled(True)
y = x * 2
y.requires_grad

True

In [23]:
_ = torch.set_grad_enabled(False)
y = x * 2
y.requires_grad

False
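
In its function form, `set_grad_enabled` changes the grad mode of the current thread until something changes it again. A sketch of one defensive pattern, assuming you want to leave the caller's mode untouched: save the previous state with `torch.is_grad_enabled()` and restore it in a `finally` block.

```python
import torch

x = torch.tensor([1.], requires_grad=True)

prev = torch.is_grad_enabled()       # remember the caller's grad mode
try:
    torch.set_grad_enabled(False)    # function form: stays in effect until changed
    y = x * 2
finally:
    torch.set_grad_enabled(prev)     # restore whatever mode was active before

print(y.requires_grad, torch.is_grad_enabled() == prev)  # False True
```

In most code the context-manager form does this bookkeeping for you.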

<a id="torch.is_grad_enabled"></a>

### 4. torch.is_grad_enabled

`torch.is_grad_enabled()`

Returns True if grad mode is currently enabled.
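
Since `torch.is_grad_enabled()` reports the grad mode of the current thread, it reflects whichever context manager is active. A small sketch; `describe_mode` is a hypothetical helper:

```python
import torch

def describe_mode() -> str:
    # Query the current thread's grad mode
    return "grad on" if torch.is_grad_enabled() else "grad off"

with torch.no_grad():
    inside_no_grad = describe_mode()
with torch.enable_grad():
    inside_enable_grad = describe_mode()

print(inside_no_grad, "|", inside_enable_grad)  # grad off | grad on
```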

In [24]:
torch.is_grad_enabled()

False

In [25]:
_ = torch.set_grad_enabled(True)
torch.is_grad_enabled()

True

<a id="torch.inference_mode"></a>

### 5. torch.inference_mode

`torch.inference_mode(mode=True)`

Context manager that enables or disables inference mode.

InferenceMode is a context manager analogous to `no_grad`, to be used when you are certain your operations will have no interactions with autograd (such as those needed for model training). Code run under this mode gets better performance because view tracking and version-counter bumps are disabled. Note that, unlike some other mechanisms that locally enable or disable grad, entering `inference_mode` also disables forward-mode AD.
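
The "no interactions with autograd" condition is enforced. Tensors created in inference mode ("inference tensors") cannot later be saved for backward by an op that autograd records. A sketch of this behaviour and of the clone workaround; it assumes grad mode is enabled outside the block, as it is by default:

```python
import torch

x = torch.ones(3, requires_grad=True)

with torch.inference_mode():
    t = torch.ones(3) * 2           # created in inference mode: an inference tensor

print(t.is_inference())             # True

# Outside inference mode, x * t must save t to compute the gradient w.r.t. x
try:
    z = x * t
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)

# A clone made outside inference mode is a normal tensor that autograd can use
t_normal = t.clone()
z = x * t_normal
print(t_normal.is_inference(), z.requires_grad)  # False True
```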

This context manager is thread local; it will not affect computation in other threads.

It also works as a decorator.

In [26]:
x = torch.ones(1, 2, 3, requires_grad=True)
with torch.inference_mode():
    y = x * x
y.requires_grad

False

In [27]:
y._version

RuntimeError: Inference tensors do not track version counter.

In [28]:
@torch.inference_mode()
def func(x):
    return x * x
out = func(x)
out.requires_grad

False

In [29]:
@torch.inference_mode
def doubler(x):
    return x * 2
out = doubler(x)
out.requires_grad

False

<a id="torch.is_inference_mode_enabled"></a>

### 6. torch.is_inference_mode_enabled

`torch.is_inference_mode_enabled()`

Returns True if inference mode is currently enabled.
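
A sketch that queries the mode at several points. It also uses the `mode` argument from the `torch.inference_mode(mode=True)` signature to switch inference mode off in a nested block. Inside inference mode, `torch.is_grad_enabled()` also reports `False`.

```python
import torch

with torch.inference_mode():
    inf_on = torch.is_inference_mode_enabled()          # True
    grad_on = torch.is_grad_enabled()                   # False: grad mode is off too
    with torch.inference_mode(False):                   # mode=False disables it here
        inf_nested = torch.is_inference_mode_enabled()  # False

inf_after = torch.is_inference_mode_enabled()           # False: previous state restored

print(inf_on, grad_on, inf_nested, inf_after)  # True False False False
```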

In [30]:
torch.is_inference_mode_enabled()

False

In [33]:
with torch.inference_mode():
    print(torch.is_inference_mode_enabled())

True
