In [6]:
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

import deepinv

from deepinv.physics.phase_retrieval import RandomPhaseRetrieval, PseudoRandomPhaseRetrieval
from deepinv.utils import randn_like
from deepinv.physics.forward import adjoint_function
from deepinv.optim.data_fidelity import L2, AmplitudeLoss
from deepinv.utils.plotting import plot
from deepinv.utils.demo import load_url_image, get_image_url
from deepinv.optim.phase_retrieval import cosine_similarity, correct_global_phase, spectral_methods

In [13]:
def compute_norm(model,x0, max_iter=100, tol=1e-3, verbose=True):
    r"""
    Computes the spectral :math:`\ell_2` norm (Lipschitz constant) of the operator

    :math:`A^{\top}A`, i.e., :math:`\|A^{\top}A\|`.

    using the `power method <https://en.wikipedia.org/wiki/Power_iteration>`_.

    :param torch.Tensor x0: initialisation point of the algorithm
    :param int max_iter: maximum number of iterations
    :param float tol: relative variation criterion for convergence
    :param bool verbose: print information

    :returns z: (float) spectral norm of :math:`A^{\top}A`, i.e., :math:`\|A^{\top}A\|`.
    """
    x = torch.randn_like(x0)
    x /= torch.norm(x)
    zold = torch.zeros_like(x)
    print("zold",zold)
    for it in range(max_iter):
        y = model.A(x)
        y = model.A_adjoint(y)
        z = torch.matmul(x.reshape(-1), y.reshape(-1)) / torch.norm(x) ** 2
        #print("z",z)
        rel_var = torch.norm(z - zold)
        if rel_var < tol and verbose:
            print(
                f"Power iteration converged at iteration {it}, value={z.item():.2f}"
            )
            break
        zold = z
        x = y / torch.norm(y)

    return z

In [2]:
# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)

device = deepinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

url = get_image_url("CBSD_0010.png")
x = load_url_image(url, grayscale=False).to(device)

x = torch.tensor(x, device=device, dtype=torch.float)
x = torch.nn.functional.interpolate(x, size=(16, 16))

x = x.to(torch.cfloat)

  x = torch.tensor(x, device=device, dtype=torch.float)


In [16]:
physics = PseudoRandomPhaseRetrieval(2, (3, 16, 16),2)

torch.allclose(x - physics.B_adjoint(physics.B(x)),torch.tensor(0.0+0.0j))
physics(x).shape

torch.Size([1, 3, 20, 20])

In [18]:
physics2 = RandomPhaseRetrieval(5*3*16*16, (3, 16, 16))

In [178]:
a = torch.tensor([1,2,3,4,5,6])
torch.fft.fft(a)

tensor([21.+0.0000j, -3.+5.1962j, -3.+1.7321j, -3.+0.0000j, -3.-1.7321j,
        -3.-5.1962j])

In [3]:
x = torch.randn((1, 1, 3, 3), dtype=torch.float, requires_grad=True)
physics = deepinv.physics.CompressedSensing(
    m=10, img_shape=(1, 3, 3), dtype=torch.float
)
loss = L2()
func = lambda x: loss(x, torch.ones_like(physics(x)), physics)[0]
# grad_value = torch.func.grad(func)(x)
print(loss.grad(x, torch.ones_like(physics(x)), physics))
# assert torch.isclose(grad_value[0], jvp_value, rtol=1e-5).all()
print(torch.func.grad(func)(x))

tensor([[[[ 1.5126,  3.6288,  2.1499],
          [ 1.2216,  2.1887, -0.8486],
          [ 0.4221,  6.5647,  0.8041]]]], grad_fn=<ViewBackward0>)
tensor([[[[ 1.5126,  3.6288,  2.1499],
          [ 1.2216,  2.1887, -0.8486],
          [ 0.4221,  6.5647,  0.8041]]]], grad_fn=<ViewBackward0>)


In [94]:
x = torch.randn((1, 1, 3, 3), dtype=torch.cfloat, requires_grad=True)
physics = deepinv.physics.RandomPhaseRetrieval(
    m=10, img_shape=(1, 3, 3), dtype=torch.cfloat
)
print(x)
y = physics.A(x)
print(y)
print(physics.A_dagger(y))

tensor([[[[ 0.0797+1.0785j,  0.2736+1.5841j, -0.5817+0.4442j],
          [ 1.4804+0.0176j, -0.5149-0.3055j, -0.8389+0.5117j],
          [-1.6124-1.0875j, -1.7800+0.9745j,  0.4210+1.6172j]]]],
       requires_grad=True)
tensor([[6.1272, 0.5628, 3.1397, 2.6942, 4.7124, 0.2080, 3.1730, 0.1719, 3.9874,
         0.8332]], grad_fn=<PowBackward0>)
tensor([[[[nan+nanj, nan+nanj, nan+nanj],
          [nan+nanj, nan+nanj, nan+nanj],
          [nan+nanj, nan+nanj, nan+nanj]]]], grad_fn=<CloneBackward0>)


In [4]:
y = physics(x)
y.shape
physics.B_adjoint(y)

tensor([[[[ 0.8858-1.4252j,  0.4933-1.2174j, -1.4996+1.6350j],
          [-0.3503-0.1512j,  1.0030+0.2949j,  0.4901+0.2424j],
          [ 2.2457+0.3787j, -0.1436-2.0321j, -0.8641-0.9089j]]]],
       grad_fn=<ViewBackward0>)

In [11]:
hasattr(physics, 'B')

True

In [354]:
loss = AmplitudeLoss()
x = torch.randn((5, 1, 3, 3), dtype=torch.cfloat, requires_grad=True)
physics = deepinv.physics.CompressedSensing(
    m=10, img_shape=(1, 3, 3), dtype=torch.cfloat
)
loss(x, torch.ones_like(physics(x)), physics)

tensor([6.4934, 9.5337, 6.0647, 5.2444, 4.5400], grad_fn=<PowBackward0>)

In [20]:
x = torch.Tensor([1, 2, 3])
func = lambda x: x.sum()
grad_value = torch.func.grad(func)(x)
grad_value

tensor([1., 1., 1.])

In [6]:
physics = deepinv.physics.CompressedSensing(
    m=10, img_shape=(1, 3, 3), dtype=torch.cfloat
)
physics.compute_norm(torch.randn((1, 1, 3, 3),dtype=torch.cfloat))

Power iteration converged at iteration 12, value=2.86-0.00j


tensor(2.8610)

In [2]:
A = torch.eye(3, dtype=torch.float64)
def A_forward(v):
    return A @ v

physics = deepinv.physics.Physics(A=A_forward)

x = torch.tensor([1,2,3],dtype=torch.float64)
physics(x)

tensor([1., 2., 3.], dtype=torch.float64)

In [10]:
for _ in range(100):
    x = torch.randn(3, dtype=torch.float64)
    print("x",x)
    v = torch.randn(3, dtype=torch.float64)
    print("v",v)
    print("A_jvp",v)
    assert torch.allclose(physics.A_jvp(x, x), x)

x tensor([-0.0437, -2.2669, -1.1095], dtype=torch.float64)
v tensor([-0.4212,  0.5659,  0.7871], dtype=torch.float64)
A_jvp tensor([-0.4212,  0.5659,  0.7871], dtype=torch.float64)
x tensor([1.9426, 1.4612, 0.9582], dtype=torch.float64)
v tensor([ 0.4383, -1.1001, -0.6980], dtype=torch.float64)
A_jvp tensor([ 0.4383, -1.1001, -0.6980], dtype=torch.float64)
x tensor([-0.3678, -0.8955, -1.6981], dtype=torch.float64)
v tensor([ 0.2353, -2.0151, -1.0602], dtype=torch.float64)
A_jvp tensor([ 0.2353, -2.0151, -1.0602], dtype=torch.float64)
x tensor([-0.6248,  1.6312, -0.2137], dtype=torch.float64)
v tensor([-0.1252, -0.5810, -1.0189], dtype=torch.float64)
A_jvp tensor([-0.1252, -0.5810, -1.0189], dtype=torch.float64)
x tensor([-2.8340,  0.6608, -1.3493], dtype=torch.float64)
v tensor([0.8876, 0.0263, 1.3302], dtype=torch.float64)
A_jvp tensor([0.8876, 0.0263, 1.3302], dtype=torch.float64)
x tensor([ 0.4767, -1.5042,  0.2863], dtype=torch.float64)
v tensor([-0.1405, -0.9067, -0.9800], dtype=t

In [10]:
seed = torch.manual_seed(0) # Random seed for reproducibility
x = torch.randn((1, 1, 3, 3),dtype=torch.cfloat) # Define random 3x3 image
physics = RandomPhaseRetrieval(m=10,img_shape=(1, 3, 3))
physics(x)

tensor([[1.1901, 4.0743, 0.1858, 2.3197, 0.0734, 0.4557, 0.1231, 0.6597, 1.7768,
         0.3864]])

In [3]:
seed = torch.manual_seed(0) # Random seed for reproducibility
x = torch.randn((1, 1, 3, 3),dtype=torch.cfloat) # Define random 3x3 image
physics = RandomPhaseRetrieval(m=10, img_shape=(1, 3, 3))
physics(x)

tensor([[1.1901, 4.0743, 0.1858, 2.3197, 0.0734, 0.4557, 0.1231, 0.6597, 1.7768,
         0.3864]])

In [4]:
# assert A_jvp and autograd works the same
x = torch.randn((1, 1, 3, 3),dtype=torch.cfloat,requires_grad=True)
loss = AmplitudeLoss()
print(loss(x,torch.ones_like(physics(x)),physics))
grad_value = torch.autograd.grad(loss(x,torch.ones_like(physics(x)),physics),x)
jvp_value = loss.grad(x, torch.ones_like(physics(x)), physics)
torch.isclose(grad_value[0],jvp_value,rtol=1e-5)

tensor([2.2016], grad_fn=<PowBackward0>)


tensor([[[[True, True, True],
          [True, True, True],
          [True, True, True]]]])

In [740]:
x = torch.randn((1, 1, 3, 3), dtype=torch.cfloat, device='cpu', requires_grad=True)
physics = deepinv.physics.RandomPhaseRetrieval(m=10, img_shape=(1, 3, 3), device='cpu')
loss = L2()
grad_value = torch.autograd.grad(loss(x,torch.ones_like(physics(x)),physics),x)[0]
jvp_value = loss.grad(x, torch.ones_like(physics(x)), physics)
print(torch.isclose(grad_value[0],jvp_value,rtol=1e-5).all())

tensor(True)


In [10]:
a = torch.Tensor([1,2,3])
a.requires_grad = True
b = torch.sum(a**2)
torch.autograd.grad(b,a)

(tensor([2., 4., 6.]),)

In [9]:
u = randn_like(x)

Au = physics.B.A(u)

v = randn_like(Au)
Atv = physics.A_adjoint(v)

s1 = (v.conj() * Au).flatten().sum()

s2 = (Atv * u.conj()).flatten().sum()

print(s1,s2)

tensor(-2.9910+1.7972j) tensor(-2.9910-1.7972j)


In [10]:
A = lambda x: torch.roll(x, shifts=(1,1), dims=(2,3)) # shift image by one pixel
x = torch.randn((1, 1, 2, 2))
y = A(x)
print(y.shape)
A_adjoint = adjoint_function(A, x.shape)
print(x)
print(A_adjoint(y))
torch.allclose(A_adjoint(y), x) # we have A^T(A(x)) = x

torch.Size([1, 1, 2, 2])
tensor([[[[-0.3516, -1.3869],
          [ 0.2270,  0.8023]]]])
tensor([[[[-0.3516, -1.3869],
          [ 0.2270,  0.8023]]]])


True

In [11]:
img_size = (1,5,5)
model = RandomPhaseRetrieval(10,img_size)
x = torch.randn(img_size, device="cpu", dtype=torch.cfloat).unsqueeze(0)
y = torch.randn(img_size, device="cpu", dtype=torch.cfloat).unsqueeze(0)
#print(x.reshape(-1).shape)
#print(model.A_adjoint(model.A(x)))
#print(x.reshape(-1).shape)
#print(x.reshape(-1)@y.reshape(-1))
norm = compute_norm(model,x)
print(norm.abs())
model.forward(x)

zold tensor([[[[0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
          [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
          [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
          [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
          [0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j]]]])
Power iteration converged at iteration 12, value=0.20+0.94j
tensor(0.9611)


tensor([[27.6474,  2.0685,  2.0891,  0.8423,  1.1014,  5.0582,  2.2531,  3.1285,
          2.4926,  3.6849]])

In [7]:
x = torch.randn(img_size, device="cpu", dtype=torch.cfloat).unsqueeze(0)
x.shape
print(model.A(x).abs()**2)
print(model.forward(x))

tensor([[8.1145, 4.5697, 1.9492, 2.4946, 2.9993, 0.3669, 0.4354, 0.1007, 0.3566,
         4.3726]])
tensor([[8.1145, 4.5697, 1.9492, 2.4946, 2.9993, 0.3669, 0.4354, 0.1007, 0.3566,
         4.3726]])


In [8]:
a = torch.randn(2,2, dtype=torch.cfloat)
print(a)
print(a.T)
print(a.conj().T)
print(a @ a.T)
print(a @ a.conj().T)

tensor([[-0.0838-1.1806j, -0.6708-0.2329j],
        [-0.2557+0.1537j,  0.9800-0.2512j]])
tensor([[-0.0838-1.1806j, -0.2557+0.1537j],
        [-0.6708-0.2329j,  0.9800-0.2512j]])
tensor([[-0.0838+1.1806j, -0.2557-0.1537j],
        [-0.6708+0.2329j,  0.9800+0.2512j]])
tensor([[-0.9911+0.5104j, -0.5130+0.2292j],
        [-0.5130+0.2292j,  0.9390-0.5710j]])
tensor([[ 1.9050+0.0000j, -0.7588-0.0821j],
        [-0.7588+0.0821j,  1.1124+0.0000j]])


In [9]:
torch.linalg.pinv(a)

tensor([[-0.1951+0.7924j, -0.4214+0.3880j],
        [ 0.0117+0.2403j,  0.8712+0.3906j]])

In [10]:
p = torch.randn(2,2, dtype=torch.cfloat)
q = torch.randn(2,2, dtype=torch.cfloat)
print(p.reshape(-1))
print(q.reshape(-1))
print(torch.vdot(p.reshape(-1),q.reshape(-1)))

tensor([ 0.3328+0.4799j,  0.3259-0.3818j,  0.5162+0.5780j, -0.2479+0.8042j])
tensor([-0.6631+0.0180j, -0.2426+0.4473j, -0.7994+0.4581j,  1.1511-0.2296j])
tensor(-1.0798+0.2071j)


In [11]:
#  create a constant 2-by-2 complex matrix
a = torch.tensor([1+1j, 2+2j, 3+3j, 4+4j], dtype=torch.cfloat)
b = torch.tensor([16+6j, 2+6j, 3+5j, 4+8j], dtype=torch.cfloat)
c = a.flatten()
print(c)
d = b.flatten()
print(d)
print(torch.vdot(a,b))

tensor([1.+1.j, 2.+2.j, 3.+3.j, 4.+4.j])
tensor([16.+6.j,  2.+6.j,  3.+5.j,  4.+8.j])
tensor(110.+20.j)


In [12]:
torch.vdot(torch.tensor([2, 3]), torch.tensor([2, 1]))
a = torch.tensor((1 +2j, 3 - 1j))
b = torch.tensor((2 +1j, 4 - 0j))
torch.vdot(b, a)
torch.vdot(a, b)

tensor(16.+1.j)

# Evaluate the performance of spectral methods w.r.t. oversampling ratio

In [9]:
RANGE_M = 600
REPEATS = 30
IMG_SHAPE = (1, 8, 8)

avg_cosines = []
raw_cosines = torch.zeros((RANGE_M, REPEATS))

for m in tqdm(range(1,RANGE_M+1)):
    physics = RandomPhaseRetrieval(m=m, img_shape=IMG_SHAPE)
    cosines = []
    for i in range(REPEATS):
        x = torch.randn((1,) + IMG_SHAPE,dtype=torch.cfloat)
        y = physics(x)
        x_hat = spectral_methods(y,physics)
        x_hat = x_hat * torch.sqrt(y.sum())
        cosine = cosine_similarity(x,x_hat)
        cosines.append(cosine)
        raw_cosines[m-1, i] = cosines[-1]
    avg_cosines.append(sum(cosines)/len(cosines))
oversampling_ratio = [i/np.prod(IMG_SHAPE) for i in range(1,RANGE_M+1)]

100%|██████████| 600/600 [01:13<00:00,  8.17it/s]
