In [1]:
import time
import torch
import matplotlib.pyplot as plt
from deepinv.models import DRUNet
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
n_channels = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
img_shape = (5, 1, 128, 128)

In [3]:
denoiser = DRUNet(
    in_channels=n_channels,
    out_channels=n_channels,
    pretrained="download",  # automatically downloads the pretrained weights, set to a path to use custom weights.
    train=False,
    device=device,
)

In [4]:
img = torch.randn(img_shape, device=device, dtype=torch.cfloat)
img_abs = torch.abs(img)
img_phase = torch.angle(img)
img_abs.dtype, img_phase.dtype

(torch.float32, torch.float32)

In [5]:
# process amplitude and phase separately
start = time.time()
for _ in range(100):
    img = torch.randn(img_shape, device=device, dtype=torch.cfloat)
    img_abs = torch.abs(img)
    img_phase = torch.angle(img)
    img_abs_denoised = denoiser(img_abs, sigma=0.03)
    img_phase_denoised = denoiser(img_phase, sigma=0.03)
    img_denoised = img_abs_denoised * torch.exp(1j * img_phase_denoised)
end = time.time()
print("Time taken: ", end - start)

Time taken:  157.50350689888


In [6]:
# process amplitude and phase together
start = time.time()
for _ in range(100):
    img = torch.randn(img_shape, device=device, dtype=torch.cfloat)
    img_abs = torch.abs(img)
    img_phase = torch.angle(img)
    noisy_batch = torch.cat((img_abs, img_phase), 0)
    denoised_batch = denoiser(noisy_batch, sigma=0.03)
    #print(denoised_batch[:img_abs.shape[0]].shape, denoised_batch[img_phase.shape[0]:].shape)
    img_denoised = denoised_batch[:img_abs.shape[0]] * torch.exp(1j * denoised_batch[img_phase.shape[0]:])
end = time.time()
print("Time taken: ", end - start)

Time taken:  155.910080909729


In [103]:
torch.manual_seed(0)

x = torch.randn(3, 3, 9, 9)

x_1 = x/torch.linalg.norm(x, keepdim=True)  # Normalizes the full batch
x_2 = x/torch.linalg.norm(x, dim=(2,3), keepdim=True)  # Normalizes each sub-tensor of the batch

print(x_1[2].norm())  # torch.linalg.norm(x_1) = 1
print(x_2[2].norm())  # will contain unit elements

torch.linalg.norm(x_2, dim=(2,3))

tensor(0.5055)
tensor(1.)


tensor([[1.0000],
        [1.0000],
        [1.0000]])

In [110]:
x = torch.randn(3, 3, 9, 9)
x[0] = x[0] / torch.linalg.norm(x[0])
x[1] = x[1] / torch.linalg.norm(x[1])
x[2] = x[2] / torch.linalg.norm(x[2])
x[0].norm()

tensor(1.)

In [80]:
y = torch.randn(3, 100)
res = y / torch.mean(y, dim=1,keepdim=True)
res[0].mean(), res[1].mean(), res[2].mean()

(tensor(1.0000), tensor(1.0000), tensor(1.0000))

In [183]:
x = torch.randn(3, 3, 23, 17)
normalized_x = torch.stack([subtensor/subtensor.norm() for subtensor in x])

x[0].norm(), normalized_x[0].norm(), x[1].norm(), normalized_x[1].norm(), x[2].norm(), normalized_x[2].norm()

(tensor(34.5784),
 tensor(1.),
 tensor(34.7137),
 tensor(1.0000),
 tensor(32.8030),
 tensor(1.))

In [114]:
for subtensor in x:
    print(subtensor.shape)

torch.Size([3, 9, 9])
torch.Size([3, 9, 9])
torch.Size([3, 9, 9])
