In [1]:
from __init__ import * 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
img =  plt.imread('images/mra_4.jpeg')[40:-25, 147:-206, 0] / 255
device = 'cpu'
mask = torch.ones((1, 1, 1024, 1024))
mri = MRI(mask, device)
sam = Sampler(device)
prox_htv_ = prox_htv(device)
htv = prox_htv_.htv
tv_l2_  = TV_l2(device)

In [3]:
c_2 = F.pad(torch.from_numpy(img)[None, None, :, :].double().to(device), (1, 1, 1, 1))
c_1 = sam.upsample(c_2)
c_half = sam.upsample(c_1)
c_quad = sam.upsample(c_half)

print(c_2.size(), c_1.size(), c_half.size(), c_quad.size())

torch.Size([1, 1, 513, 513]) torch.Size([1, 1, 1025, 1025]) torch.Size([1, 1, 2049, 2049]) torch.Size([1, 1, 4097, 4097])


In [4]:
c_2_in = c_2[:, :, :-1, :-1]
c_1_in = c_1[:, :, :-1, :-1]
c_half_in = c_half[:, :, :-1, :-1]
c_quad_in = c_quad[:, :, :-1, :-1]

print(c_2_in.size(), c_1_in.size(), c_half_in.size(), c_quad_in.size())

torch.Size([1, 1, 512, 512]) torch.Size([1, 1, 1024, 1024]) torch.Size([1, 1, 2048, 2048]) torch.Size([1, 1, 4096, 4096])


In [5]:
# check forward 
mri.set_h(1)
f_hat_1 = mri.H(c_1_in)
mri.set_h(2)
f_hat_2 = mri.H(c_2_in)
mri.set_h(0.5)
f_hat_half = mri.H(c_half_in)
mri.set_h(0.25)
f_hat_quad = mri.H(c_quad_in)

assert (torch.mean(torch.abs(f_hat_1-f_hat_2))) < 1e-5
assert (torch.mean(torch.abs(f_hat_1-f_hat_half))) < 1e-5
assert (torch.mean(torch.abs(f_hat_2-f_hat_half))) < 1e-5
assert (torch.mean(torch.abs(f_hat_quad-f_hat_half))) < 1e-5

In [6]:
# check adjoint

mri.set_h(1)
x = torch.normal(0, 1, (1, 1, 1024, 1024)).double() 
y_prime = torch.normal(0, 1, (1, 1, 1024, 1024)).double() 
y = mri.H(y_prime)

Hx = mri.H(x)
Hty = mri.Ht(y)

Hxy = (torch.conj(Hx) * y).sum()
xHty = (x * Hty).sum()

print(Hxy, xHty)

mri.set_h(2)
x = torch.normal(0, 1, (1, 1, 512, 512)).double() 
y_prime = torch.normal(0, 1, (1, 1, 512, 512)).double() 
y = mri.H(y_prime)

Hx = mri.H(x)
Hty = mri.Ht(y)

Hxy = (torch.conj(Hx) * y).sum()
xHty = (x * torch.conj(Hty)).sum()

print(Hxy, xHty)


mri.set_h(0.5)
x = torch.normal(0, 1, (1, 1, 2048, 2048)).double() 
y_prime = torch.normal(0, 1, (1, 1, 2048, 2048)).double() 
y = mri.H(y_prime)

Hx = mri.H(x)
Hty = mri.Ht(y)

Hxy = (torch.conj(Hx) * y).sum()
xHty = (x * Hty).sum()

print(Hxy, xHty)

mri.set_h(0.25)
x = torch.normal(0, 1, (1, 1, 4096, 4096)).double() 
y_prime = torch.normal(0, 1, (1, 1, 4096, 4096)).double() 
y = mri.H(y_prime)

Hx = mri.H(x)
Hty = mri.Ht(y)

Hxy = (torch.conj(Hx) * y).sum()
xHty = (x * Hty).sum()

print(Hxy, xHty)

tensor(156.6114+6.3962j, dtype=torch.complex128) tensor(156.6114, dtype=torch.float64)
tensor(-1932.1724+2.2737e-13j, dtype=torch.complex128) tensor(-1932.1724, dtype=torch.float64)
tensor(-359.8729-1.7334j, dtype=torch.complex128) tensor(-359.8729, dtype=torch.float64)
tensor(-2.0489-1.6078j, dtype=torch.complex128) tensor(-2.0489, dtype=torch.float64)


In [7]:
# check htv 
htv_2 = htv.L(c_2).abs().sum()
htv_1 = htv.L(c_1).abs().sum()
htv_half = htv.L(c_half).abs().sum()
htv_quad = htv.L(c_quad).abs().sum()

print(htv_1, htv_2, htv_half, htv_quad)

tensor(9328.5490, dtype=torch.float64) tensor(9328.5490, dtype=torch.float64) tensor(9328.5490, dtype=torch.float64) tensor(9328.5490, dtype=torch.float64)


In [8]:
# check tv l2
tv_2 = torch.norm(tv_l2_.L(c_2, 2), p=2, dim=1).sum()
tv_1 = torch.norm(tv_l2_.L(c_1, 1), p=2, dim=1).sum()
tv_half = torch.norm(tv_l2_.L(c_half, 0.5), p=2, dim=1).sum()
tv_quad = torch.norm(tv_l2_.L(c_quad, 0.25), p=2, dim=1).sum()

print(tv_1, tv_2, tv_half, tv_quad)

tensor(14242.7321, dtype=torch.float64) tensor(14242.7321, dtype=torch.float64) tensor(14242.7321, dtype=torch.float64) tensor(14242.7321, dtype=torch.float64)
