In [None]:
import cupy as cp
import matplotlib.pyplot as plt
from functions import R,RT,mshow_complex,mshow
from phantom import *

# Generate phantom (Doga)

In [None]:
#####################################################
# Build the synthetic 3D vector-field phantom
#####################################################

scale = 1
shape = (scale*64,)*3

# Sphere centres, one row per sphere.
centers = np.array([
    (scale*32, scale*32, scale*32),
    (scale*24, scale*36, scale*32),
    (scale*42, scale*24, scale*32),
])
# Sphere radii, in the same order as `centers`.
radii = np.array([scale*24, scale*10, scale*6])
# One angle pair per sphere, forwarded as `domain_angles`.
domains = np.array([
    (np.pi/6, np.pi/2),
    (np.pi/6, -np.pi/2),
    (np.pi/6, -np.pi/2),
])

# Generate the vector field and its support mask.
field, mask = create_vector_field_phantom_3d(
    shape, centers, radii, domain_angles=domains, transition_width=scale*0.0)

print (field.shape)

# Quiver plots of components 0 and 1 on three orthogonal slices.
for _u, _v in (
    (field[:, :, 32, 0], field[:, :, 32, 1]),
    (field[:, 32, :, 0], field[:, 32, :, 1]),
    (field[25, :, :, 0], field[25, :, :, 1]),
):
    plt.quiver(_u, _v)
    plt.show()


In [None]:
# Device (CuPy) copy of the phantom field; used as the ground-truth object below.
w = cp.array(field, copy=True)

## Tomography parameters

In [None]:
# 60 projection angles covering [0, pi], both end points included.
theta = cp.linspace(0, cp.pi, num=60, endpoint=True).astype('float32')
w_shape = w.shape
# Rotation axis placed at the middle of the second object axis.
rotation_axis = w_shape[1]/2
# Projection data layout: (first object axis, angle, third object axis).
data_shape = [w_shape[0], len(theta), w_shape[2]]


## Find mask

In [None]:
# The support mask is already provided by the phantom generator.
# Append the trailing (vector-component) axis only when it is still missing, so
# that re-running this cell does not keep adding dimensions to `mask`.
mask = cp.asarray(mask)
if mask.ndim == len(w_shape) - 1:
    mask = mask[...,cp.newaxis]
mshow(mask[w_shape[0]//2,...,0],True)
# Alternative: derive the mask from the field magnitude instead.
# w_abs = np.linalg.norm(w,axis=-1)
# mshow(w_abs[w_shape[0]//2],True)

# mask = (w_abs>0).astype('float32')[...,cp.newaxis]


# Make operators

### $P_\theta(w) = \mathcal{R}_\theta (w_x) \cos\theta + \mathcal{R}_\theta (w_y) \sin\theta$

In [None]:
def P(w):
    """Forward operator: angle-weighted sum of the projections of the field components.

    Each component ``w[..., k]`` is passed through ``R`` (with the global
    ``theta`` and ``rotation_axis``) and its real part is scaled by a per-angle
    weight before accumulation into an array of shape ``data_shape``.

    NOTE(review): the weights used here are -cos(theta) and -sin(theta), while
    the markdown formula above this cell shows +cos/+sin; the third weight is
    zero and was flagged with '###???' in the original code.
    """
    weights = (
        -cp.cos(theta)[:,cp.newaxis],
        -cp.sin(theta)[:,cp.newaxis],
        0,
    )
    Rw = cp.zeros(data_shape,dtype='float32')
    for k, weight in enumerate(weights):
        Rw += weight*R(w[...,k].astype('complex64'),theta,rotation_axis).real
    return Rw
   
def Padj(Rw):
    """Counterpart of ``P`` built from ``RT``: map projection data to a field.

    For every component the data are scaled by the same per-angle weight that
    ``P`` uses, passed through ``RT`` and the real part is stored in
    ``w[..., k]``. The result has shape ``w_shape`` and dtype float32.
    """
    weights = (
        -cp.cos(theta)[:,cp.newaxis],
        -cp.sin(theta)[:,cp.newaxis],
        0,  # zero weight for the third component, mirroring P
    )
    w = cp.zeros(w_shape,dtype='float32')
    for k, weight in enumerate(weights):
        w[...,k] = RT((Rw*weight).astype('complex64'),theta,rotation_axis).real
    return w

In [None]:

# Project the phantom and show the central slice of the first two field
# components next to the central slice of the resulting data.
data = P(cp.array(field))
for _img in (
    field[field.shape[0]//2,...,0],
    field[field.shape[0]//2,...,1],
    data[data.shape[0]//2],
):
    mshow(_img,True)

# Adjoint test

In [None]:
# Dot-product test for the pair (P, Padj): the two printed inner products
# should agree if Padj is the adjoint of P.
a = cp.random.random(w_shape,dtype='float32')
b = cp.random.random(data_shape,dtype='float32')
bb = P(a)
aa = Padj(b)
for _lhs, _rhs in ((aa, a), (bb, b)):
    print(cp.sum(_lhs*_rhs))

# Data simulation

In [None]:
# Simulate the measured data from the current object `w`.
rotation_axis = data_shape[-1]//2
data = P(w)
# Central slices along the first and second data axes.
for _view in (data[data.shape[0]//2], data[:,data.shape[1]//2]):
    mshow(_view)

### $F_1(w) = \sum_\theta\|P_\theta(M(w))-d_\theta\|_2^2$

In [None]:

def F1(w):
    """Data-fidelity term: squared norm of ``P(mask*w) - data``."""
    residual = P(mask*w)-data
    return cp.linalg.norm(residual)**2

def dF1(w,y):
    """First-order term of ``F1`` at ``w`` in direction ``y``."""
    residual = P(mask*w)-data
    projected_direction = P(mask*y)
    return 2*cp.sum(residual*projected_direction)

def d2F1(w,y,z):
    """Second-order (bilinear) term of ``F1`` for the directions ``y`` and ``z``.

    ``F1`` is quadratic in ``w``, so the result does not depend on ``w``; the
    argument is kept for a uniform signature with the other ``d2F*`` functions.

    Returns ``2 * <P(mask*y), P(mask*z)>``. The previous implementation used
    ``y`` for both factors and silently ignored ``z``, which is only correct
    when the two directions coincide.
    """
    return 2*cp.sum(P(mask*y)*P(mask*z))

def dF1adj(w,y):
    """Adjoint of ``dF1`` at ``w`` applied to ``y``: ``2*mask*Padj(residual)*y``."""
    residual = P(mask*w)-data
    return 2*mask*Padj(residual)*y

### Approximation test

In [None]:
# Taylor check for F1: compare F1(w+dw) with its zeroth-, first- and
# second-order expansions for 20 growing step sizes.
w = cp.random.random(w_shape).astype('float32')
dw0 = cp.random.random(w_shape).astype('float32')/10

l = cp.linspace(0,0.1,20).astype('float32')
err1, err2, err3 = (np.zeros(20) for _ in range(3))
for k in range(20):
    dw = l[k]*dw0
    a = F1(w+dw)
    err1[k] = cp.linalg.norm(F1(w)-a)
    err2[k] = cp.linalg.norm(F1(w)+dF1(w,dw)-a)
    err3[k] = cp.linalg.norm(F1(w)+dF1(w,dw)+0.5*d2F1(w,dw,dw)-a)

for _curve, _label in ((err1,'f'), (err2,'linear'), (err3,'quadr')):
    plt.plot(_curve,label=_label)
plt.legend()
plt.grid()
# plt.yscale('log')
plt.show()
# ss

### Adjoint test

In [None]:
# Dot-product test for (dF1, dF1adj): the two printed numbers should agree.
w = cp.random.random(w_shape).astype('float32')
dw = cp.random.random(w_shape).astype('float32')/10

a = dF1(w,dw)
b = dF1adj(w,a)

for _lhs, _rhs in ((a, a), (dw, b)):
    print(cp.sum(_lhs*_rhs))

### $ F_{21}(w,v) = \|M(w)\cdot v\|_2^2$

In [None]:
def F21(x):
    """Orientation term ``|| (mask*w) . v ||^2``.

    Parameters
    ----------
    x : sequence ``[w, v]``
        ``w`` is the vector field and ``v`` the vector it is dotted with along
        the last axis.

    The mask is applied out of place: the previous ``w *= mask`` overwrote the
    caller's array (e.g. ``vars['w']`` when called through ``calc_err``).
    """
    w,v = x
    masked_w = w*mask
    wv = cp.sum(masked_w*v,axis=-1)
    return cp.linalg.norm(wv)**2

def dF21(x,y):
    """First-order term of ``F21`` at ``x = [w, v]`` in direction ``y = [dw, dv]``.

    Both ``w`` and ``dw`` are masked out of place; the previous in-place
    ``*=`` silently modified the arrays owned by the caller. ``dw`` may also be
    the scalar ``0`` (as produced by ``d2F22``), which the product handles.
    """
    w,v = x
    dw,dv = y
    masked_w = w*mask
    masked_dw = dw*mask
    t1 = cp.sum(masked_w*v,axis=-1)
    t2 = cp.sum(masked_dw*v,axis=-1)+cp.sum(masked_w*dv,axis=-1)
    return 2*cp.sum(t1*t2)

def d2F21(x,y,z):
    """Second-order (bilinear) term of ``F21`` at ``x = [w, v]``.

    ``y = [dw1, dv1]`` and ``z = [dw2, dv2]`` are the two directions. The
    result is ``2*(t1 + t2)`` with
    ``t1 = <w.v, dw1.dv2 + dw2.dv1>`` and
    ``t2 = <dw1.v + w.dv1, dw2.v + w.dv2>`` (all fields masked).

    Masking is done out of place so that the caller's ``w``, ``dw1`` and
    ``dw2`` are left untouched (the previous ``*=`` modified them).
    """
    w,v = x
    dw1,dv1 = y
    dw2,dv2 = z
    w_m = w*mask
    dw1_m = dw1*mask
    dw2_m = dw2*mask

    # Term coupling the current value with the mixed second derivative.
    wv = cp.sum(w_m*v,axis=-1)
    dw1v2 = cp.sum(dw1_m*dv2,axis=-1)
    dw2v1 = cp.sum(dw2_m*dv1,axis=-1)
    t1 = cp.sum(wv*(dw1v2+dw2v1))

    # Product of the two first-order variations.
    dw1v = cp.sum(dw1_m*v,axis=-1)
    wdv1 = cp.sum(w_m*dv1,axis=-1)
    dw2v = cp.sum(dw2_m*v,axis=-1)
    wdv2 = cp.sum(w_m*dv2,axis=-1)
    t2 = cp.sum((dw1v+wdv1)*(dw2v+wdv2))
    return 2*(t1+t2)

def dF21adj(x,y):
    """Adjoint of ``dF21`` at ``x = [w, v]`` applied to ``y``.

    Returns ``[dw, dv]``: ``dw`` has the shape of ``w`` (masked) and ``dv`` is
    obtained by summing over the first three axes.

    ``w`` is masked out of place; the previous ``w *= mask`` overwrote the
    caller's array (``vars['w']`` when called from ``gradient``).
    """
    w,v = x
    w_m = w*mask
    wv = cp.sum(w_m*v,axis=-1)*y
    dw = 2*wv[...,cp.newaxis]*v
    dv = 2*cp.sum(wv[...,cp.newaxis]*w_m,axis=(0,1,2))
    dw = dw*mask
    return [dw,dv]



### Approximation test

In [None]:
# Taylor check for F21: compare F21 at the perturbed point with its zeroth-,
# first- and second-order expansions for 20 growing step sizes.
w = cp.random.random(w_shape).astype('float32')
dw0 = cp.random.random(w_shape).astype('float32')/10
v = cp.random.random(3).astype('float32')
dv0 = cp.random.random(3).astype('float32')/10

l = cp.linspace(0,0.1,20).astype('float32')
err1 = np.zeros(20)
err2 = np.zeros(20)
err3 = np.zeros(20)
for k in range(20):
    dw = l[k]*dw0
    dv = l[k]*dv0
    a = F21([w+dw,v+dv])
    #print(cp.linalg.norm(dF21([w,v],[dw,0*dv])),cp.linalg.norm(dF21([w,v],[0*dw,dv])))
    err1[k] = cp.linalg.norm(F21([w,v])-a)
    # The linear model is F + dF (the first-order term was subtracted before,
    # which is inconsistent with the quadratic model on the next line).
    err2[k] = cp.linalg.norm(F21([w,v])+dF21([w,v],[dw,dv])-a)
    err3[k] = cp.linalg.norm(F21([w,v])+dF21([w,v],[dw,dv])+0.5*d2F21([w,v],[dw,dv],[dw,dv])-a)
plt.plot(err1,label='f')
plt.plot(err2,label='linear')
plt.plot(err3,label='quadr')
plt.legend()  # the curves are labelled, so show the legend like the other tests
plt.grid()
# plt.yscale('log')
plt.show()

### Adjoint test for the gradient

In [None]:

# Dot-product test for (dF21, dF21adj): the two printed numbers should agree.
w = cp.random.random(w_shape).astype('float32')
v = cp.random.random(3).astype('float32')
dw = cp.random.random(w_shape).astype('float32')/4
dv = cp.random.random(3).astype('float32')/4

a = dF21([w,v],[dw,dv])
b = dF21adj([w,v],a)

print(cp.sum(a*a))
_adj_w, _adj_v = b[0], b[1]
print(cp.sum(dw*_adj_w)+cp.sum(dv*_adj_v))

### $F_{22}(\phi) = [\cos(\phi_1) \sin(\phi_2), \sin(\phi_1) \sin(\phi_2), \cos(\phi_2)] $

In [None]:
def F22(x):
    """Map ``[w, phi]`` to ``[w, v]`` where ``v`` is the unit vector with angles ``phi``.

    ``v = [cos(phi0) sin(phi1), sin(phi0) sin(phi1), cos(phi1)]``; ``w`` is
    passed through unchanged (same object).
    """
    w,phi = x
    azimuth, polar = phi[0], phi[1]
    components = [
        cp.cos(azimuth)*cp.sin(polar),
        cp.sin(azimuth)*cp.sin(polar),
        cp.cos(polar),
    ]
    return [w,cp.array(components)]

def dF22(x,y):
    """First-order term of ``F22`` at ``x = [w, phi]`` in direction ``y = [dw, dphi]``.

    The field part is the identity (``dw`` is returned as is); the angle part
    is the 2x3 Jacobian of the unit vector applied to ``dphi``.
    """
    w,phi = x
    dw,dphi = y
    # Row 0: derivative w.r.t. phi[0]; row 1: derivative w.r.t. phi[1].
    jac = cp.array(
        [[-cp.sin(phi[0])*cp.sin(phi[1]),cp.cos(phi[0])*cp.sin(phi[1]), cp.array(0)],
         [cp.cos(phi[0])*cp.cos(phi[1]),cp.sin(phi[0])*cp.cos(phi[1]),-cp.sin(phi[1])]])
    dv = [jac[0,col]*dphi[0]+jac[1,col]*dphi[1] for col in range(3)]
    return [dw,cp.array(dv)]

def d2F22(x,y,z):
    """Second-order term of ``F22`` at ``x = [w, phi]`` for directions ``y`` and ``z``.

    The field part of ``F22`` is linear, so its second-order term is ``0``; the
    angle part combines the second partial derivatives of the unit vector.
    Returns ``[0, res]``.
    """
    w,phi = x
    dw1,dphi1 = y
    dw2,dphi2 = z

    # Second partial derivatives of the unit vector w.r.t. (phi0,phi0),
    # (phi1,phi1) and the mixed one (phi0,phi1).
    second_00 = cp.array([-cp.cos(phi[0])*cp.sin(phi[1]),-cp.sin(phi[0])*cp.sin(phi[1]),cp.array(0)])
    second_11 = cp.array([-cp.cos(phi[0])*cp.sin(phi[1]),-cp.sin(phi[0])*cp.sin(phi[1]),-cp.cos(phi[1])])
    second_01 = cp.array([-cp.sin(phi[0])*cp.cos(phi[1]),cp.cos(phi[0])*cp.cos(phi[1]),cp.array(0)])

    coef_00 = dphi1[0]*dphi2[0]
    coef_11 = dphi1[1]*dphi2[1]
    coef_01 = dphi1[0]*dphi2[1]+dphi1[1]*dphi2[0]
    res = coef_00*second_00+coef_11*second_11+coef_01*second_01
    return [0,res]
  
def dF22adj(x,y):
    """Adjoint of ``dF22`` at ``x = [w, phi]`` applied to ``y = [dw, dv]``.

    ``dw`` is returned unchanged; the 3-vector ``dv`` is mapped to a float32
    2-vector by the Jacobian of the unit vector (one entry per angle).
    """
    w,phi = x
    dw,dv = y

    # Row 0: derivative w.r.t. phi[0]; row 1: derivative w.r.t. phi[1].
    jac = cp.array(
        [[-cp.sin(phi[0])*cp.sin(phi[1]),cp.cos(phi[0])*cp.sin(phi[1]), cp.array(0)],
         [cp.cos(phi[0])*cp.cos(phi[1]),cp.sin(phi[0])*cp.cos(phi[1]),-cp.sin(phi[1])]])
    dphi = cp.zeros([2],dtype='float32')
    for row in range(2):
        dphi[row] = jac[row,0]*dv[0] + jac[row,1]*dv[1] + jac[row,2]*dv[2]

    return [dw,dphi]


### Approximation test

In [None]:
  
# Taylor check for F22 in the angle variables only (the field perturbation is
# set to zero inside the loop).
w = cp.random.random(w.shape).astype('float32')*mask
dw0 = cp.random.random(w.shape).astype('float32')/10*mask
v = cp.random.random(3).astype('float32')
dv0 = cp.random.random(3).astype('float32')/10
phi = cp.random.random(2).astype('float32')
dphi0 = cp.random.random(2).astype('float32')/3

l = cp.linspace(0,1,20).astype('float32')
err1, err2, err3 = (np.zeros(20) for _ in range(3))
for k in range(20):
    dphi = l[k]*dphi0
    dw = 0*dw0
    a = F22([w+dw,phi+dphi])
    a1 = F22([w,phi])
    a2 = dF22([w,phi],[dw,dphi])
    a3 = d2F22([w,phi],[dw,dphi],[dw,dphi])
    # Squared distance between the exact value and each expansion order,
    # summed over the field part [0] and the vector part [1].
    err1[k] = cp.linalg.norm(a1[0]-a[0])**2+cp.linalg.norm(a1[1]-a[1])**2
    err2[k] = cp.linalg.norm(a1[0]+a2[0]-a[0])**2+cp.linalg.norm(a1[1]+a2[1]-a[1])**2
    err3[k] = cp.linalg.norm(a1[0]+a2[0]+0.5*a3[0]-a[0])**2+cp.linalg.norm(a1[1]+a2[1]+0.5*a3[1]-a[1])**2

for _curve, _label in ((err1,'f'), (err2,'linear'), (err3,'quadr')):
    plt.plot(_curve,label=_label)
plt.legend()
plt.grid()
plt.yscale('log')
plt.show()

## Adjoint test for the gradient

In [None]:
# Dot-product test for (dF22, dF22adj): the two printed numbers should agree.
w = cp.random.random(w_shape).astype('float32')
phi = cp.random.random(2).astype('float32')
dw = cp.random.random(w_shape).astype('float32')/4
dphi = cp.random.random(2).astype('float32')/4

a = dF22([w,phi],[dw,dphi])
b = dF22adj([w,phi], a)

_a_field, _a_vec = a[0], a[1]
print(cp.sum(_a_field*_a_field)+cp.sum(_a_vec*_a_vec))
_b_field, _b_phi = b[0], b[1]
print(cp.sum(dw*_b_field)+cp.sum(dphi*_b_phi))


### $F_3(w,a) = \|M(|w|-a)\|_2^2$

In [None]:
def F3(x):
    """Magnitude term ``|| mask * (|w| - a) ||^2`` for ``x = [w, a]``.

    ``|w|`` is the Euclidean norm of the field along its last axis.
    """
    w,a = x
    magnitude = cp.linalg.norm(w, axis=-1)
    deviation = mask[...,0]*(magnitude-a)
    return cp.linalg.norm(deviation)**2

def dF3(x,y):
    """First-order term of ``F3`` at ``x = [w, a]`` in direction ``y = [dw, da]``.

    A small constant (1e-7) is added to ``|w|`` before dividing to avoid a
    division by zero where the field vanishes.
    """
    w,a = x
    dw,da = y

    w_abs = cp.linalg.norm(w, axis=-1)
    n_one = w_abs*0+1
    mask_norm_sq = cp.linalg.norm(n_one*mask[...,0])**2

    # Variation with respect to the field.
    field_part = 2*cp.sum((w-a*w/(w_abs[...,cp.newaxis]+1e-7))*dw*mask)
    # Variation with respect to the scalar a (two contributions).
    cross_part = -2*cp.sum(w_abs*mask[...,0])*da
    scalar_part = 2*a*da*mask_norm_sq
    return field_part+cross_part+scalar_part
    
def d2F3(x,y,z):
    """Second-order (bilinear) term of ``F3`` at ``x = [w, a]``.

    ``y = [dw1, da1]`` and ``z = [dw2, da2]`` are the two directions. The
    result is twice the sum of: the pure ``a`` term, the two mixed ``(w, a)``
    terms and the two pure ``w`` terms.
    """
    w,a = x
    dw1,da1 = y
    dw2,da2 = z

    w_abs = cp.linalg.norm(w, axis=-1)
    n_one = w_abs*0+1
    # Normalised field (regularised against division by zero).
    w_hat = w/(w_abs[...,cp.newaxis]+1e-7)
    support = mask[...,0]

    # Pure scalar term.
    t1 = cp.linalg.norm(n_one*support)**2*da1*da2
    # Mixed field/scalar terms.
    t2 = -cp.sum(w_hat*dw1*mask)*da2
    t3 = -cp.sum(w_hat*dw2*mask)*da1

    # Pure field terms.
    dw1dw2 = cp.sum(dw1*dw2*mask,axis=-1)
    t4 = cp.sum((1-a/(w_abs+1e-7))*dw1dw2*support)

    wdw1 = cp.sum(w_hat*dw1,axis=-1)
    wdw2 = cp.sum(w_hat*dw2,axis=-1)
    t5 = cp.sum(a/(w_abs+1e-7)*wdw1*wdw2*support)

    return 2*(t1+t2+t3+t4+t5)

def dF3adj(x,y):
    """Adjoint of ``dF3`` at ``x = [w, a]`` applied to ``y``; returns ``[dw, da]``."""
    w,a = x

    w_abs = cp.linalg.norm(w, axis=-1)
    n_one = w_abs*0+1
    support = mask[...,0]

    # Field part: same shape as w, restricted to the mask.
    dw = 2*(w-a*w/(w_abs[...,cp.newaxis]+1e-7)) * mask * y
    # Scalar part: the two contributions of the derivative w.r.t. a.
    da = (-2*cp.sum(w_abs*support) + 2*a*cp.linalg.norm(n_one*support)**2) * y

    return [dw,da]


# Adjoint test

In [None]:
# Dot-product test for (dF3, dF3adj): the two printed numbers should agree.
w = cp.random.random(w_shape).astype('float32')
a = cp.random.random(1).astype('float32')
dw = cp.random.random(w_shape).astype('float32')/4
da = cp.random.random(1).astype('float32')/4

t = dF3([w,a],[dw,da])
tt = dF3adj([w,a], t)

print(cp.sum(t*t))
_adj_w, _adj_a = tt[0], tt[1]
print(cp.sum(dw*_adj_w)+cp.sum(da*_adj_a))

# s

# Approximation test

In [None]:
  
# Taylor check for F3: squared error of the zeroth-, first- and second-order
# expansions for 20 growing step sizes.
w = cp.random.random(w.shape).astype('float32')
dw0 = cp.random.random(w.shape).astype('float32')/10
a = cp.random.random(1).astype('float32')
da0 = cp.random.random(1).astype('float32')/10

l = cp.linspace(0,1,20).astype('float32')
err1, err2, err3 = (np.zeros(20) for _ in range(3))
for k in range(20):
    dw = l[k]*dw0
    da = l[k]*da0
    t = F3([w+dw,a+da])
    t1 = F3([w,a])
    t2 = dF3([w,a],[dw,da])
    t3 = d2F3([w,a],[dw,da],[dw,da])
    err1[k] = cp.linalg.norm(t1-t)**2
    err2[k] = cp.linalg.norm(t1+t2-t)**2
    err3[k] = cp.linalg.norm(t1+t2+0.5*t3-t)**2

for _curve, _label in ((err1,'f'), (err2,'linear'), (err3,'quadr')):
    plt.plot(_curve,label=_label)
plt.yscale('log')
plt.legend()
plt.grid()
plt.show()


## Gradients and hessians

In [None]:
def gradient(vars,lam,d):
    """Gradient of the full objective with respect to ``w``, ``phi`` and ``a``.

    Parameters
    ----------
    vars : dict with keys ``'w'``, ``'phi'``, ``'a'``
    lam : pair of weights for the orientation and magnitude terms
    d : not used in the body (the data are read from the global ``data``
        inside ``dF1adj``)

    Returns a dict with the same three keys.
    """
    w,phi,a = vars['w'],vars['phi'],vars['a']

    # Data-fidelity term.
    grad_data = dF1adj(w,1)

    # Orientation term: chain rule through F22 followed by F21.
    grad_orient_w,grad_phi = dF22adj([w,phi],dF21adj(F22([w,phi]),1))

    # Magnitude term.
    grad_magn_w,grad_a = dF3adj([w,a],1)

    return {
        'w': grad_data + lam[0]*grad_orient_w + lam[1]*grad_magn_w,
        'phi': lam[0]*grad_phi,
        'a': lam[1]*grad_a,
    }

def hessian(vars,grads,etas,lam):
    """Second-order term of the full objective for the directions ``grads`` and ``etas``.

    All three arguments are dicts with keys ``'w'``, ``'phi'``, ``'a'``; the
    result is the weighted sum of the second-order terms of F1, F21(F22) and F3.
    """
    w,phi,a = vars['w'],vars['phi'],vars['a']
    dw1,dphi1,da1 = grads['w'],grads['phi'],grads['a']
    dw2,dphi2,da2 = etas['w'],etas['phi'],etas['a']

    # Data-fidelity term.
    data_part = d2F1(w,dw1,dw2)

    # Orientation term: second-order chain rule for the composition F21(F22(.)).
    point = [w,phi]
    first = [dw1,dphi1]
    second = [dw2,dphi2]
    orient_part = (d2F21(F22(point),dF22(point,first),dF22(point,second))
                   +dF21(F22(point),d2F22(point,first,second)))

    # Magnitude term.
    magn_part = d2F3([w,a],[dw1,da1],[dw2,da2])

    return data_part+lam[0]*orient_part+lam[1]*magn_part


# debug functions

In [None]:
def calc_err(w,phi,a,lam):
    """Value of the full objective: ``F1 + lam[0]*F21(F22) + lam[1]*F3``."""
    data_term = F1(w)
    orient_term = F21(F22([w,phi]))
    magn_term = F3([w,a])
    return data_term+lam[0]*orient_term+lam[1]*magn_term

def plot_debug(vars, etas, top, bottom, alpha, lam):
    """Plot the objective along the search direction against its quadratic model.

    The objective is sampled at 18 step sizes ``alpha*k/8`` along ``etas`` and
    compared with ``f(0) - top*t + 0.5*bottom*t**2``.
    """
    w, phi, a = vars['w'],vars['phi'],vars['a']
    weta, phieta, aeta = etas['w'],etas['phi'],etas['a']
    npp = 9
    nsamples = npp * 2

    # Objective evaluated at each sampled step.
    errt = cp.zeros(nsamples)
    for k in range(nsamples):
        step = alpha * k / (npp - 1)
        errt[k] = calc_err(w + step * weta, phi + step * phieta, a + step * aeta, lam)

    # Quadratic model on the same step grid.
    t = alpha * (cp.arange(nsamples)) / (npp - 1)
    errt2 = calc_err(w,phi,a,lam) - top * t + 0.5 * bottom * t**2

    print(f'{phi=},{a=}')
    for values, marker_label in ((errt, "real"), (errt2, "approx")):
        plt.plot(t.get(),values.get(),".",label=marker_label)
    plt.legend()
    plt.grid()
    plt.show()

def mplot3(a):
    """Show the three last-axis components of a device array side by side in gray scale."""
    fig, axs = plt.subplots(1, 3, figsize=(8,2))
    for k, ax in enumerate(axs):
        # .get() copies the component from the device to host memory.
        im = ax.imshow(a[...,k].get(), cmap="gray")
        fig.colorbar(im, fraction=0.046, pad=0.04)
    plt.show()

def vis_debug(vars):
    """Print the current ``phi`` and ``a`` and visualise the current field ``vars['w']``.

    Shows, for three orthogonal slices: the three field components, the field
    magnitude, and a quiver plot of components 0 and 1.
    """
    # Outer double quotes: reusing the enclosing quote character inside an
    # f-string expression is a SyntaxError before Python 3.12. The printed
    # text is unchanged because '=' echoes the expression itself.
    print(f"{vars['phi']=}")
    print(f"{vars['a']=}")

    w_cur = vars['w']
    # Central slices are taken from w_shape (the shape the operators use)
    # rather than from whatever the global `w` currently is.
    mid_1 = w_shape[1]//2
    mid_2 = w_shape[2]//2

    mplot3(w_cur[:,:,mid_2])
    mplot3(w_cur[:,mid_1])
    mplot3(w_cur[25])
    # The field is a CuPy array, so the norm is computed with CuPy directly.
    mshow(cp.linalg.norm(w_cur[:,:,mid_2],axis=-1),True)
    mshow(cp.linalg.norm(w_cur[:,mid_1],axis=-1),True)
    mshow(cp.linalg.norm(w_cur[25],axis=-1),True)

    # Quiver plots on the same fixed slices as used for the phantom.
    field = w_cur.get()
    for u, v in (
        (field[:, :, 32, 0], field[:, :, 32, 1]),
        (field[:, 32, :, 0], field[:, 32, :, 1]),
        (field[25, :, :, 0], field[25, :, :, 1]),
    ):
        plt.figure(figsize=(4,4))
        plt.quiver(u, v)
        plt.show()



# BH 
### $\operatorname*{arg\,min}_{w,\varphi,a}\; F_1(w)+\lambda_1 F_{21}(F_{22}(w,\varphi))+\lambda_2 F_{3}(w,a) =\sum_\theta\|P_\theta(M(w))-d_\theta\|_2^2+\lambda_1\|M(w)\cdot \Theta(\varphi_1,\varphi_2)\|_2^2+\lambda_2\|M(|w|-a)\|_2^2$

In [None]:
def BH(vars, d, niter,lam):
    """Iteratively minimise the full objective over ``w``, ``phi`` and ``a``.

    Each iteration computes the gradient, updates the search direction
    ``etas`` (steepest descent on the first iteration, afterwards combined
    with the previous direction through ``beta``), and takes a step whose
    length ``alpha`` is obtained from the first- and second-order terms.

    Parameters
    ----------
    vars : dict with keys ``'w'``, ``'phi'``, ``'a'`` (updated and returned)
    d : forwarded to ``gradient``
    niter : number of iterations
    lam : pair of regularisation weights

    Returns ``(vars, err)`` where ``err`` holds the objective value at every
    64th iteration (zeros elsewhere).
    """
    names = ('w', 'phi', 'a')
    err = cp.zeros(niter)
    for i in range(niter):

        # Periodic logging of the objective.
        if i%64==0:
            err[i] = calc_err(vars['w'],vars['phi'],vars['a'],lam)
            print(i,f'err={err[i]}')
        grads = gradient(vars,lam,d)

        if i == 0:
            etas = {name: -grads[name] for name in names}
        else:
            top = hessian(vars,grads,etas,lam)
            bottom = hessian(vars,etas,etas,lam)
            beta = top / bottom
            for name in names:
                etas[name] = etas[name] * beta - grads[name]

        # Step length along etas.
        slope_w = cp.sum(grads['w']*etas['w'])
        slope_phi = cp.sum(grads['phi']*etas['phi'])
        slope_a = cp.sum(grads['a']*etas['a'])
        top = -slope_w - slope_phi - slope_a
        bottom = hessian(vars, etas, etas, lam)
        alpha = top / bottom

        # Periodic diagnostics.
        if i%128==0:
            plot_debug(vars,etas,top,bottom,alpha,lam)
            vis_debug(vars)

        for name in names:
            vars[name] += alpha * etas[name]
    return vars,err


In [None]:
# Initial guess: random field in [-1, 1) restricted to the mask, fixed start
# angles and unit magnitude.
# cp.random.seed()
vars = {
    'w': (cp.random.random(field.shape).astype('float32')-0.5)*2*mask,
    'phi': cp.array([cp.pi/6+cp.pi/2,cp.pi/2],dtype='float32'),
    'a': cp.float32(1),
}
lam = [0.0001,0.1]
niter = 102400
vars,err2 = BH(vars,data,niter,lam)



In [None]:
# Show components 0 and 1 of the reconstructed field as real/imaginary parts
# on two central slices, followed by the field magnitude.
mshow_complex(vars['w'][w_shape[0]//2,:,:,0]+1j*vars['w'][w_shape[0]//2,:,:,1])
# Both components are taken from the same slice of the third axis (the
# imaginary part previously indexed that axis with w_shape[1]//2).
mshow_complex(vars['w'][:,:,w_shape[2]//2,0]+1j*vars['w'][:,:,w_shape[2]//2,1])
# CuPy norm on the CuPy array; slice index from w_shape like the lines above.
mshow(cp.linalg.norm(vars['w'][w_shape[0]//2],axis=-1),True)

In [None]:
# Compare the projection of the reconstruction with the data at the first angle.
pw = P(vars['w'])

for _proj in (pw, data):
    mshow(_proj[:,0],True)
# print(data.shape)
# mshow(data[:,1],True)
# mshow(data[32,:],True)

In [None]:
# # initial guess
# vars = {}
# vars['w'] = w*0
# vars['phi'] = cp.array([1,1],dtype='float32')
# vars['a'] = cp.float32(0)
# lam = [0.0000,0.0]
# niter = 1024
# vars,err2 = BH(vars,data,niter,lam)

# mshow_complex(vars['w'][w_shape[0]//2,:,:,0]+1j*vars['w'][w_shape[0]//2,:,:,1])
# mshow_complex(vars['w'][:,w_shape[1]//2,:,0]+1j*vars['w'][:,w_shape[1]//2,:,1])
# mshow(np.linalg.norm(vars['w'][w.shape[0]//2],axis=-1),True)

In [None]:
# vars['phi']