Skip to content

Commit

Permalink
fix: 🐛 fix OOM on GPU for nrqm
Browse files Browse the repository at this point in the history
  • Loading branch information
chaofengc committed Mar 20, 2023
1 parent 08f8850 commit fcb7f6e
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pyiqa/archs/nrqm_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,16 @@ def norm_sender_normalized(pyr, num_scale=2, num_bands=6, blksz=3, eps=1e-12):
o_c = o_c.reshape(b, hw)
o_c = o_c - o_c.mean(dim=1, keepdim=True)

if tmp.shape[1] >= 2e5: # To avoid out of GPU memory
C_x = C_x.cpu()
tmp = tmp.cpu()
if hasattr(torch.linalg, 'lstsq'):
tmp_y = torch.linalg.lstsq(C_x.transpose(1, 2), tmp.transpose(1, 2)).solution.transpose(1, 2) * tmp / N
else:
warn(
"For numerical stability, we use torch.linal.lstsq to calculate matrix inverse for PyTorch > 1.9.0. The results might be slightly different if you use older version of PyTorch.")
tmp_y = (tmp @ torch.linalg.pinv(C_x)) * tmp / N
tmp_y = tmp_y.to(o_c)

z = tmp_y.sum(dim=2).sqrt()
mask = z != 0
Expand Down

0 comments on commit fcb7f6e

Please sign in to comment.