In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np



In [48]:
def n_CNN(i_chan,
          o_chan,
          n_layers = 1,
          n_chan = 64,
          kernel_size = 3,
          pad = 1,
          stride = 1,
          dilation = 1,
          bias = True,
          bn = True,
          activation = nn.Tanh()
          ):
    """Build a stack of ``n_layers`` [Conv3d -> (BatchNorm3d) -> activation] blocks.

    Args:
        i_chan: input channels of the first convolution.
        o_chan: output channels of the last convolution.
        n_layers: number of convolution layers; must be >= 1.
        n_chan: hidden channel widths between consecutive convolutions. An int
            (or a one-element list/tuple) is repeated for all ``n_layers - 1``
            hidden widths; otherwise a sequence of exactly ``n_layers - 1``
            widths. Ignored when ``n_layers == 1``.
        kernel_size, pad, stride, dilation, bias: forwarded unchanged to every
            ``nn.Conv3d`` (``pad`` is passed as ``padding``).
        bn: if True, an ``nn.BatchNorm3d`` follows every convolution.
        activation: module appended after every conv/BN pair. The SAME instance
            is reused for every layer (and, as a default argument, shared across
            calls), so only pass parameter-free activations unless weight
            sharing is intended.

    Returns:
        nn.Sequential with the layers in order.

    Raises:
        ValueError: if ``n_layers < 1`` or ``n_chan`` has the wrong length.
    """
    if n_layers < 1:
        raise ValueError("n_layers must be >= 1, got {}".format(n_layers))

    n_hidden = n_layers - 1
    if n_hidden == 0:
        hidden = []                      # single layer: n_chan is unused
    elif isinstance(n_chan, int):
        hidden = [n_chan] * n_hidden
    else:
        hidden = list(n_chan)
        if len(hidden) == 1:
            hidden = hidden * n_hidden
        if len(hidden) != n_hidden:
            raise ValueError(
                "n_chan must give 1 or n_layers - 1 = {} widths, got {}".format(
                    n_hidden, len(hidden)))

    # Channel count at every layer boundary: i_chan -> hidden... -> o_chan.
    widths = [i_chan] + hidden + [o_chan]

    cnn3d = []
    for c_in, c_out in zip(widths[:-1], widths[1:]):
        cnn3d.append(nn.Conv3d(c_in, c_out, kernel_size = kernel_size, stride = stride,
                               bias = bias, dilation = dilation, padding = pad))
        if bn:
            # Normalise over this layer's OUTPUT channels (an int, never a list).
            cnn3d.append(nn.BatchNorm3d(c_out))
        cnn3d.append(activation)

    return nn.Sequential(*cnn3d)

In [62]:
class Net_cnn(nn.Module):
    """3D CNN: two parallel first-layer branches, summed, then three conv layers.

    All convolutions are unpadded with stride 1 (nn.Conv3d defaults), so the
    kernel sizes 5, 5, 3, 3 along the main path shrink every spatial dimension
    by 12 voxels: a (N, 1, 15, 15, 15) input yields a (N, 1, 3, 3, 3) output.
    """

    def __init__(self):
        super(Net_cnn, self).__init__()
        # Two parallel 1 -> 20 channel branches on the raw input.
        self.conv1 = nn.Conv3d(1,20,5)
        self.conv1_p = nn.Conv3d(1,20,5)
        # NOTE(review): norm1 is applied to BOTH the conv1 and conv1_p outputs, so
        # the two branches share affine parameters and running statistics (the
        # latter updated twice per training forward). If independent normalisation
        # was intended, add a separate BatchNorm3d for conv1_p (changes state_dict keys).
        self.norm1 = nn.BatchNorm3d(20)
        self.conv2 = nn.Conv3d(20,40,5)
        self.norm2 = nn.BatchNorm3d(40)
        self.conv3 = nn.Conv3d(40,20,3)
        self.norm3 = nn.BatchNorm3d(20)
        self.conv4 = nn.Conv3d(20,1,3)
        self.norm4 = nn.BatchNorm3d(1)

    def forward(self, x):
        """Map x of shape (N, 1, D, H, W) to shape (N, 1, D-12, H-12, W-12)."""
        # torch.tanh replaces the deprecated F.tanh; the values are identical.
        out1 = torch.tanh(self.norm1(self.conv1(x)))
        # tanhshrink(v) = v - tanh(v)
        out2 = F.tanhshrink(self.norm1(self.conv1_p(x)))
        out = out1 + out2
        out = F.tanhshrink(self.norm2(self.conv2(out)))
        out = F.tanhshrink(self.norm3(self.conv3(out)))
        out = F.tanhshrink(self.norm4(self.conv4(out)))
        return out

In [63]:
from FFT import fft

# Build a (1, 1, X, Y, Z) array of squared spatial frequencies |k|^2 on the
# Phase grid; the training cell uses it as a k-space multiplier:
# LP = IFT(k2 * FT(Phase)).
Phase = np.random.randn(15,15,15)        # only its shape is used in this cell
# Project FFT wrapper; axes=(2,3,4) presumably the spatial axes of
# (N, C, X, Y, Z) arrays -- TODO confirm against FFT.fft.
A = fft.fft(shape=1,axes=(2,3,4))
field_of_view = Phase.shape
# indexing='ij': xx varies along axis 0, yy along axis 1, zz along axis 2
# (same arrays as a default 'xy' meshgrid with its first two arguments swapped).
xx, yy, zz = np.meshgrid(np.arange(0, Phase.shape[0]),
                         np.arange(0, Phase.shape[1]),
                         np.arange(0, Phase.shape[2]),
                         indexing='ij')
# NOTE(review): under Python 3, np.round(15/2) == 8, whereas numpy.fft.fftshift
# places DC at index N//2 == 7 for odd N. Whether this off-by-one matters
# depends on the centring convention of A.FT / A.IFT (not visible here) -- confirm.
xx, yy, zz = ((xx - np.round((Phase.shape[0])/2)) / field_of_view[0],
              (yy - np.round((Phase.shape[1])/2)) / field_of_view[1],
              (zz - np.round((Phase.shape[2])/2)) / field_of_view[2])
# np.spacing(1) keeps k2 strictly positive at the k = 0 voxel so 1/k2 is finite.
k2 = np.square(xx) + np.square(yy) + np.square(zz) + np.spacing(1)
k2 = k2[None, None]                      # prepend batch and channel axes for broadcasting
ik2 = 1/k2
k2.shape



(1, 1, 15, 15, 15)

In [76]:

# Train Net_cnn to map LP = IFT(k2 * FT(Phase)) back to (the centre of) Phase,
# drawing a fresh random batch of 30 volumes every step.
net = Net_cnn()
Iter = 100
optimizer = optim.Adam(net.parameters(), lr=0.01)
criteria = nn.MSELoss()
# Net_cnn's unpadded convolutions trim 6 voxels from each side (15 -> 3), so the
# target is cropped by the same amount to match the network output.
crop = 6

for _ in range(Iter):
    Phase = np.random.randn(30,1,15,15,15)
    LP = A.IFT(k2*A.FT(Phase)).real
    data = Variable(torch.FloatTensor(LP))
    target = Variable(torch.FloatTensor(Phase[:,:,crop:-crop,crop:-crop,crop:-crop]))
    # backward() accumulates into .grad, so gradients must be cleared every step;
    # otherwise each Adam update sees the sum of all previous steps' gradients.
    optimizer.zero_grad()
    output = net(data)
    loss = criteria(output, target)
    loss.backward()
    optimizer.step()
    print(loss)

Variable containing:
 1.0692
[torch.FloatTensor of size 1]

Variable containing:
 1.0262
[torch.FloatTensor of size 1]

Variable containing:
 1.0308
[torch.FloatTensor of size 1]

Variable containing:
 0.9757
[torch.FloatTensor of size 1]

Variable containing:
 1.0023
[torch.FloatTensor of size 1]

Variable containing:
 0.9278
[torch.FloatTensor of size 1]

Variable containing:
 0.9123
[torch.FloatTensor of size 1]

Variable containing:
 0.9464
[torch.FloatTensor of size 1]

Variable containing:
 0.9492
[torch.FloatTensor of size 1]

Variable containing:
 0.8630
[torch.FloatTensor of size 1]

Variable containing:
 0.9148
[torch.FloatTensor of size 1]

Variable containing:
 0.8656
[torch.FloatTensor of size 1]

Variable containing:
 0.8707
[torch.FloatTensor of size 1]

Variable containing:
 0.8676
[torch.FloatTensor of size 1]

Variable containing:
 0.8332
[torch.FloatTensor of size 1]

Variable containing:
 0.8405
[torch.FloatTensor of size 1]

Variable containing:
 0.8216
[torch.Floa

In [52]:
from scipy.signal import convolve2d, correlate2d
from torch.autograd import Function
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class ScipyConv2dFunction(Function):
    """Autograd function: 2-D 'valid' cross-correlation computed with SciPy.

    ``ScipyConv2dFunction.apply(input, filter)`` returns
    ``correlate2d(input, filter, mode='valid')`` for 2-D CPU tensors ``input``
    (H, W) and ``filter`` (kh, kw); the result has shape (H-kh+1, W-kw+1) and
    the dtype of ``input``.
    """

    @staticmethod
    def forward(ctx, input, filter):
        # .detach() is required before .numpy() on tensors that require grad.
        result = correlate2d(input.detach().numpy(), filter.detach().numpy(), mode='valid')
        ctx.save_for_backward(input, filter)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        input, filter = ctx.saved_tensors
        g = grad_output.detach().numpy()
        # Forward: y[i, j] = sum_{m, n} x[i+m, j+n] * w[m, n]. Therefore
        #   dL/dx = full *convolution* of g with w (no transpose of w), and
        #   dL/dw = valid *cross-correlation* of x with g.
        grad_input = convolve2d(g, filter.detach().numpy(), mode='full')
        grad_filter = correlate2d(input.detach().numpy(), g, mode='valid')
        return (torch.as_tensor(grad_input, dtype=input.dtype),
                torch.as_tensor(grad_filter, dtype=filter.dtype))


class ScipyConv2d(Module):
    """Module applying ScipyConv2dFunction with a learnable (kh, kw) filter."""

    def __init__(self, kh, kw):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(kh, kw))

    def forward(self, input):
        # New-style autograd functions are invoked through .apply, not instantiated.
        return ScipyConv2dFunction.apply(input, self.filter)

array([2, 3, 4, 5, 6, 7])

In [91]:
# Scratch check: indexing a range from the end yields its last element (4).
p = range(0, 5, 1)
p[len(p) - 1]

4

In [14]:
import NUFFT
from NUFFT import kb128

# kernel & grid def
width = 3
J_c = (J+1)//2
grid_r = np.array([[-64,64],[-64,64],[-64,64]])

kb_kernel = kb128.kb128
kb_kernel = np.array(kb_kernel[0:J])
traj = np.repeat(np.array([[1,2,3,4,5]]),3,axis=0)
traj = np.reshape(traj,[3,-1])
samples = traj.shape[1]

kx = traj[0,:]
ky = traj[1,:]
kz = traj[2,:]
wx_t = kx - np.floor(kx)
wy_t = ky - np.floor(ky)
wz_t = kz - np.floor(kz)

w = np.zeros([3,traj.shape[1]],width)
def KB_3d(grid, kb_table,width):
    # grid[3,N] kb_table[128]
    # low accuracy
    k
    

for i in range(samples):
    ind_x = np.array([np.round(np.maximum(kx[i]-width,grid_r[0,0])):np.floor(np.minimum(kx[i]+width,grid_r[0,1]))])
    ind_y = np.array([np.round(np.maximum(ky[i]-width,grid_r[1,0])):np.floor(np.minimum(ky[i]+width,grid_r[1,1]))])
    ind_z = np.array([np.round(np.maximum(kz[i]-width,grid_r[2,0])):np.floor(np.minimum(kz[i]+width,grid_r[2,1]))])
    kgrid_y,kgrid_x,kgrid_z = np.meshgrid(ind_y,ind_x,ind_z)
    kgrid = np.concatenate(1,kgrid_x.flatten(),kgrid_y.flatten(),kgrid_z.flatten())
    weight = KB_3d(
    kernel = value[kgrid_x.reshape(-1),kgrid_y.reshape(-1),kgrid_z.reshape(-1)]
    K_kernel = grid_r

In [24]:
# Scratch cell. J is read by the NUFFT gridding cell (J_c = (J+1)//2 and the
# kb_kernel[0:J] slice), so this cell must run before that one.
J = 5
t = np.array([[0, 1, 1],
              [2, 4, 5]])
np.round(1.4)   # evaluates to 1.0

1.0