Skip to content

Commit

Permalink
Change no_grad -> inference_mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Corpsecreate committed Feb 13, 2024
1 parent 1ddad19 commit 2dcafe4
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion neosr/archs/swinir_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def forward(self, x, mask=None):
# The context manager might break graphs when using .compile(),
# solution seems to use it outside SDPA instead.

with torch.no_grad():
with torch.inference_mode():
with torch.backends.cuda.sdp_kernel(enable_math=False):
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=self.scale, dropout_p=self.dropout_p)
x = x.transpose(1, 2).reshape(b_, n, c)
Expand Down
4 changes: 2 additions & 2 deletions neosr/models/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def test(self):
img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')

self.net_g.eval()
with torch.no_grad():
with torch.inference_mode():
self.output = self.net_g(img)
self.net_g.train()

Expand Down Expand Up @@ -748,7 +748,7 @@ def reduce_loss_dict(self, loss_dict):
Args:
loss_dict (OrderedDict): Loss dict.
"""
with torch.no_grad():
with torch.inference_mode():
if self.opt['dist']:
keys = []
losses = []
Expand Down
4 changes: 2 additions & 2 deletions neosr/models/otf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, opt):
self.queue_size = opt.get('queue_size', 180)
self.device = torch.device('cuda')

@torch.no_grad()
@torch.inference_mode()
def _dequeue_and_enqueue(self):
"""It is the training pair pool for increasing the diversity in a batch.
Expand Down Expand Up @@ -61,7 +61,7 @@ def _dequeue_and_enqueue(self):
b, :, :, :] = self.gt.clone()
self.queue_ptr = self.queue_ptr + b

@torch.no_grad()
@torch.inference_mode()
def feed_data(self, data):
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
"""
Expand Down

0 comments on commit 2dcafe4

Please sign in to comment.