In [1]:
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from models.meshfreeSR import UNet3D,meshfreeSR
from models.ndInterp import NDLinearInterpolation

def GaussianRing(grid,radius,sigma):
	"""Evaluate a Gaussian shell profile centred on a sphere of given radius.

	The Euclidean distance ``r`` from the origin is computed from the first
	three components of ``grid``; the result is the 1-D normal density
	``N(r; radius, sigma)``, i.e. a Gaussian bump across the shell ``r == radius``.

	Args:
		grid: indexable container whose entries ``grid[0]``, ``grid[1]`` and
			``grid[2]`` hold the x, y and z coordinates (same shape each).
			The NumPy ufuncs below are applied to whatever array type is
			passed in; the notebook passes torch tensors as well as arrays.
		radius: radius of the shell (scalar, or broadcastable to ``r``).
		sigma: standard deviation of the profile across the shell.

	Returns:
		Values of the profile, with the shape of ``grid[0]``.
	"""
	r = np.sqrt(grid[0]**2+grid[1]**2+grid[2]**2)
	# Use the full-precision constant: the previous literal 3.1415 biased the
	# normalisation by roughly 1.5e-5 (relative).
	norm = 1.0/(sigma*np.sqrt(2*np.pi))
	return norm*np.exp(-0.5*((r-radius)/sigma)**2)

def GaussianRing2D(grid,radius,sigma):
	"""Evaluate a Gaussian ring profile centred on a circle of given radius.

	The planar distance ``r`` from the origin is computed from the first two
	components of ``grid``; the result is the 1-D normal density
	``N(r; radius, sigma)``.

	Args:
		grid: indexable container whose entries ``grid[0]`` and ``grid[1]``
			hold the x and y coordinates (same shape each). The NumPy ufuncs
			below are applied to whatever array type is passed in.
		radius: radius of the ring (scalar, or broadcastable to ``r``).
		sigma: standard deviation of the profile across the ring.

	Returns:
		Values of the profile, with the shape of ``grid[0]``.
	"""
	r = np.sqrt(grid[0]**2+grid[1]**2)
	# Use the full-precision constant instead of the truncated literal 3.1415.
	norm = 1.0/(sigma*np.sqrt(2*np.pi))
	return norm*np.exp(-0.5*((r-radius)/sigma)**2)

def Gaussian3D(grid,sigma):
	"""Log-density of a zero-mean isotropic 3-D normal evaluated at ``grid``.

	Note that ``sigma`` multiplies the identity *covariance* matrix (so it
	plays the role of a variance), and that the returned values are
	log-probabilities, not probabilities.

	Args:
		grid: tensor whose last dimension has size 3 (one point per row).
		sigma: scalar factor applied to the 3x3 identity covariance.

	Returns:
		Tensor of log-densities, placed on the module-level ``device``.
	"""
	# `device` is read from the notebook's global configuration.
	mean = torch.zeros(3).to(device)
	covariance = (torch.eye(3)*sigma).to(device)
	normal = torch.distributions.multivariate_normal.MultivariateNormal(mean,covariance)
	log_density = normal.log_prob(grid.to(device))
	return log_density.to(device)

def genSlice(grid,time,n_grid=32):
	"""Build one volumetric slice of ``Gaussian3D`` values for a given time.

	Args:
		grid: tensor that can be viewed as ``(3, n_grid**3)`` coordinates.
		time: value forwarded to ``Gaussian3D`` as its covariance scale.
		n_grid: number of samples per spatial axis.

	Returns:
		Tensor of shape ``(1, n_grid, n_grid, n_grid)``.
	"""
	# One row per grid point, columns = (x, y, z).
	points = grid.view(3,-1).T
	values = Gaussian3D(points,time)
	volume = values.view(n_grid,n_grid,n_grid)
	return volume.unsqueeze(0)



In [2]:
# ---- Global configuration ----
device = 'cpu'

n_dim = 4          # query coordinates are (t, x, y, z); see xmin / xmax below
n_slice = 5        # number of context time slices
n_grid = 32        # samples per spatial axis
n_layers = 3       # forwarded to the network constructors
linear_size = 256  # forwarded to meshfreeSR

# Times at which supervised training targets are sampled (between context slices).
train_times = (0.75, 1.25, 1.75)

# ---- Coordinate axes and grids ----
space_axis = torch.linspace(-5,5,n_grid).to(device)
time_axis = torch.linspace(0.5,2.5,n_slice).to(device)
grid = torch.stack(torch.meshgrid(space_axis,space_axis,space_axis)).to(device)
train_grid = torch.stack(
    torch.meshgrid(time_axis,space_axis,space_axis,space_axis)
).to(device)

# ---- Context volumes: one Gaussian shell per context time ----
# torch.as_tensor avoids the "copy construct from a tensor" UserWarning that
# torch.tensor() raises when GaussianRing already returns a tensor.
train_context = torch.stack(
    [torch.as_tensor(GaussianRing(grid.cpu(),t,t/2)).unsqueeze(0) for t in time_axis]
).to(device)

# ---- Training query locations: every grid point at each training time ----
flat_grid = grid.reshape(3,-1).T  # (n_grid**3, 3), columns = (x, y, z)
train_loc = torch.cat(
    [
        torch.cat((torch.ones((n_grid**3,1)).to(device)*t,flat_grid),dim=1)
        for t in train_times
    ],
    dim=0,
)
# Shrink all coordinates (time included) by 1%.
# NOTE(review): presumably this keeps queries strictly inside [xmin, xmax] for
# the interpolation; the targets below are still evaluated at the unscaled
# locations -- confirm this mismatch is intended.
train_loc = train_loc*0.99

# ---- Training targets: shell values at the (unscaled) training times ----
train_value = torch.cat(
    [torch.as_tensor(GaussianRing(grid.cpu(),t,t/2)).reshape(-1) for t in train_times]
).to(device)

# ---- Domain bounds in (t, x, y, z) order ----
xmin = torch.tensor([0.5,-5,-5,-5]).float().to(device)
xmax = torch.tensor([2.5,5,5,5]).float().to(device)


  app.launch_new_instance()


In [43]:
# Display the context time stamps (n_slice evenly spaced values on [0.5, 2.5]).
time_axis

tensor([0.5000, 1.0000, 1.5000, 2.0000, 2.5000])

In [67]:
# Shape check: flattening the grid gives one row per coordinate axis and one column per grid point.
grid.reshape(3,-1).shape

torch.Size([3, 32768])

In [10]:
# Shape check for the 2-D ring profile with radius i and width i/2.
# NOTE(review): the recorded 32x32 output implies `grid` was a 2-D meshgrid when
# this ran; with a 3-component (3, n, n, n) grid the result is (n, n, n).
# Confirm the intended execution order.
i = 0.5
GaussianRing2D(grid.cpu(), i, i/2).shape

torch.Size([32, 32])

In [53]:
# 2-D demo used for the scatter plot: a 40x40 grid and a single ring.
# NOTE(review): this rebinds n_grid, n_slice, space_axis, time_axis and grid,
# which are also used by the 3-D setup -- later cells see these new values.
n_slice = 20
n_grid = 40

time_axis = torch.linspace(0.5, 2.5, n_slice)
space_axis = torch.linspace(-5, 5, n_grid)
grid = torch.stack(torch.meshgrid(space_axis, space_axis))

# Ring of radius 1.5 with width 0.75 evaluated on the 2-D grid.
i = 1.5
g = GaussianRing2D(grid, radius=i, sigma=i/2)

In [54]:
# Flatten the 2-D grid and the ring values into one (x, y, z) record per grid
# point, in row-major (i outer, j inner) order.
# The previous triple loop iterated over an unused index `k`, so every point
# was appended n_grid times; each point now appears exactly once.
xs = grid[0].reshape(-1).tolist()
ys = grid[1].reshape(-1).tolist()
zs = g.reshape(-1).tolist()
points = list(zip(xs, ys, zs))

df = pd.DataFrame(points, columns=['x', 'y', 'z'])

In [56]:
import plotly.express as px

# Interactive 3-D scatter of the ring: grid coordinates on x/y, profile value on z.
fig = px.scatter_3d(df, x="x", y="y", z="z")

# Small markers with a thin dark outline, applied to marker-mode traces only.
marker_style = {"size": 2, "line": {"width": 2, "color": "DarkSlateGrey"}}
fig.update_traces(marker=marker_style, selector={"mode": "markers"})

fig.show()

In [3]:
# Fit meshfreeSR on the three supervised time slices.
# NOTE(review): constructor arguments are presumably
# (in_channels, base_channels, n_layers, n_dim, linear_size) -- confirm
# against models.meshfreeSR.
mfSR = meshfreeSR(1,16,n_layers,n_dim,linear_size)
mfSR = mfSR.to(device)
mfSR.train()

lossFunction = nn.MSELoss()
optimizer = torch.optim.Adam(mfSR.parameters(), lr=1e-3)

for i in range(100):
    optimizer.zero_grad()
    output = mfSR(train_context,train_loc,xmin,xmax)
    # First row of the transposed output is compared with the flat target vector.
    loss = lossFunction(output.T[0],train_value)
    print(loss)
    loss.backward()
    optimizer.step()

AssertionError: Torch not compiled with CUDA enabled

In [13]:
class LinearBlock(nn.Module):
	"""Fully connected layer followed by a ReLU activation.

	Args:
		in_channel: number of input features.
		out_channel: number of output features.
	"""

	def __init__(self,in_channel,out_channel):
		super().__init__()
		# Attribute names are kept unchanged: they define the state_dict keys.
		self.Linear = nn.Linear(in_features=in_channel,out_features=out_channel)
		self.Act_3c = nn.ReLU()

	def forward(self,x):
		"""Return ``relu(Linear(x))``."""
		return self.Act_3c(self.Linear(x))

class ConvBlock(nn.Module):
	"""3-D convolution -> batch normalisation -> ReLU.

	The convolution pads by ``kernel_size//2`` on every side, so with the
	default stride an odd kernel leaves the spatial size unchanged.

	Args:
		in_channel: number of input channels.
		out_channel: number of output channels.
		kernel_size: integer kernel size of the convolution.
		padding_mode: padding mode passed to ``nn.Conv3d``.
	"""

	def __init__(self,in_channel,out_channel,kernel_size,padding_mode='replicate'):
		super().__init__()
		# Attribute names are kept unchanged: they define the state_dict keys.
		self.Conv = nn.Conv3d(
			in_channel,
			out_channel,
			kernel_size,
			padding=kernel_size//2,
			padding_mode=padding_mode,
		)
		self.BatchNorm = nn.BatchNorm3d(out_channel)
		self.Act_3c = nn.ReLU()

	def forward(self,x):
		"""Apply convolution, batch norm and ReLU in that order."""
		return self.Act_3c(self.BatchNorm(self.Conv(x)))

class ResBlock(nn.Module):
	"""Residual block: 1x1x1 conv -> 3x3x3 conv -> 1x1x1 conv with a projected skip.

	The main path keeps ``in_channel`` channels through the two ``ConvBlock``
	stages and maps to ``out_channel`` in the final 1x1x1 convolution. The skip
	path is always a learned 1x1x1 convolution, so input and output channel
	counts may differ.

	Args:
		in_channel: number of input channels.
		out_channel: number of output channels.
		padding_mode: padding mode forwarded to the ``ConvBlock`` stages.
	"""

	def __init__(self,in_channel,out_channel,padding_mode='replicate'):
		super().__init__()
		# Construction order and attribute names are preserved so parameter
		# initialisation order and state_dict keys stay the same.
		self.shortcut = nn.Conv3d(in_channel,out_channel,kernel_size=1)
		self.Conv_1 = ConvBlock(in_channel,in_channel,1,padding_mode=padding_mode)
		self.Conv_2 = ConvBlock(in_channel,in_channel,3,padding_mode=padding_mode)
		self.Conv_3a = nn.Conv3d(in_channel,out_channel,kernel_size=1)
		self.BatchNorm_3b = nn.BatchNorm3d(out_channel)
		self.Act_3c = nn.ReLU()

	def forward(self,x):
		"""Return ``relu(shortcut(x) + main_path(x))``."""
		main = self.Conv_2(self.Conv_1(x))
		main = self.BatchNorm_3b(self.Conv_3a(main))
		return self.Act_3c(self.shortcut(x)+main)

class SamplingBlock(nn.Module):
	"""Resolution-changing stage built around a ``ResBlock``.

	* ``mode == 'Up'``: ``nn.Upsample(scale_factor=2)`` followed by a ResBlock.
	* ``mode == 'Down'``: a ResBlock followed by ``nn.MaxPool3d(2)``.

	Args:
		in_channel: number of input channels.
		out_channel: number of output channels.
		mode: either ``'Up'`` or ``'Down'``.
		padding_mode: padding mode forwarded to the ``ResBlock``. It was
			previously accepted but silently ignored.

	Raises:
		ValueError: if ``mode`` is neither ``'Up'`` nor ``'Down'``. Previously
			this left ``self.conv`` undefined and only failed in ``forward``.
	"""

	def __init__(self,in_channel,out_channel,mode,padding_mode='replicate'):
		super(SamplingBlock,self).__init__()
		if mode not in ('Up','Down'):
			raise ValueError("mode must be 'Up' or 'Down', got {!r}".format(mode))
		res_block = ResBlock(in_channel,out_channel,padding_mode=padding_mode)
		# Positions inside the Sequential are unchanged, so state_dict keys
		# ('conv.1.*' for Up, 'conv.0.*' for Down) match the earlier layout.
		if mode == 'Up':
			self.conv = nn.Sequential(nn.Upsample(scale_factor=2),res_block)
		else:
			self.conv = nn.Sequential(res_block,nn.MaxPool3d(2))

	def forward(self,x):
		"""Apply the up- or down-sampling stage to ``x``."""
		return self.conv(x)

class MyUNet(nn.Module):
	"""3-D U-Net built from ``ResBlock`` / ``SamplingBlock`` stages.

	Channel layout, with ``C = out_channel`` and ``n = n_pairs``:

	* stem: ``in_channel -> C``
	* down stage ``i`` (0-based): ``2**i * C -> 2**(i+1) * C``, halving resolution
	* up stage 0: ``2**n * C -> 2**(n-1) * C``; each later up stage receives the
	  previous output concatenated with a skip and reduces it to a quarter of
	  its channels

	After the last concatenation the output has ``2 * C`` channels at the input
	resolution. Each spatial size should be divisible by ``2**n_pairs`` so the
	upsampled maps match the stored skip tensors in the concatenation.

	Args:
		in_channel: number of input channels.
		out_channel: base channel count ``C``.
		n_pairs: number of down/up stage pairs (must be >= 1).
		padding_mode: padding mode forwarded to every residual block. It was
			previously accepted but silently ignored.

	Raises:
		ValueError: if ``n_pairs < 1``, which would give fractional channel
			counts.
	"""

	def __init__(self,in_channel,out_channel,n_pairs,padding_mode='replicate'):
		super(MyUNet,self).__init__()
		if n_pairs < 1:
			raise ValueError("n_pairs must be >= 1, got {!r}".format(n_pairs))
		self.Res1 = ResBlock(in_channel,out_channel,padding_mode=padding_mode)
		self.n_pairs = n_pairs
		down_stages = [
			SamplingBlock(2**i*out_channel,2**(i+1)*out_channel,'Down',padding_mode=padding_mode)
			for i in range(n_pairs)
		]
		up_stages = [
			SamplingBlock(2**n_pairs*out_channel,2**(n_pairs-1)*out_channel,'Up',padding_mode=padding_mode)
		]
		# After the first up stage every input is (upsampled, skip) concatenated,
		# i.e. 2**i * C channels, which the stage reduces to 2**(i-2) * C.
		up_stages += [
			SamplingBlock(2**i*out_channel,2**(i-2)*out_channel,'Up',padding_mode=padding_mode)
			for i in range(n_pairs,1,-1)
		]
		# Kept as nn.Sequential (used only as indexable containers) so the
		# attribute types and state_dict keys stay unchanged.
		self.Down_array = nn.Sequential(*down_stages)
		self.Up_array = nn.Sequential(*up_stages)

	def forward(self,x):
		"""Run the encoder, then the decoder with skip concatenation on the channel axis."""
		y = self.Res1(x)
		skips = []
		for i in range(self.n_pairs):
			# Store the feature map *before* each down stage for the matching up stage.
			skips.append(y)
			y = self.Down_array[i](y)
		skips.reverse()
		for i in range(self.n_pairs):
			y = self.Up_array[i](y)
			y = torch.cat((y,skips[i]),1)
		return y

In [46]:
# Shape check of the stacked context volumes (each slice carries one singleton axis from unsqueeze(0)).
train_context.shape

torch.Size([5, 1, 32, 32, 32])

In [52]:
# NOTE(review): `summary` is only referenced by the commented-out call at the
# end of this cell.
from torchsummary import summary

# Output shape of the library UNet3D on the context stack, for comparison with
# the hand-written MyUNet.
model = UNet3D(1, 16, n_layers)
model(train_context).shape
# summary(model, (1, 32, 32, 32))

torch.Size([5, 32, 32, 32, 32])

In [53]:
# Manual walk-through of the feature pipeline: run the U-Net on the context
# volumes, swap the first two axes, then interpolate the feature grid at the
# training locations.
# The model and the interpolation now follow the notebook-level `device`
# instead of a hard-coded 'cpu', so this cell stays consistent with the
# tensors created in the setup cell.
context_grid = MyUNet(1, 16, n_layers).to(device)(train_context).permute(1,0,2,3,4)
context_vector = NDLinearInterpolation(context_grid, train_loc, xmin, xmax, device = device)
context_vector.shape

torch.Size([98304, 32])

In [38]:
# Append the query coordinates to the interpolated feature vectors, column-wise.
combine = torch.cat([context_vector, train_loc], dim=1)

In [39]:
# Shape check: rows = query points, columns = interpolated features + coordinates.
combine.shape

torch.Size([98304, 36])

In [41]:
# Input width 2*16 + 4 = 36: the 2*out_channel feature channels produced by
# MyUNet(1, 16, ...) plus the 4 coordinate columns of train_loc.
x = LinearBlock(2 * 16 + 4, 32)(combine)


In [42]:
# Shape check of a freshly initialised 32 -> 1 linear head applied to the block output.
nn.Linear(32, 1)(x).shape

torch.Size([98304, 1])

In [None]:
# ---- Train ----
mfSR = meshfreeSR(1,16,n_layers,n_dim,linear_size).to(device)
mfSR.train()
optimizer = torch.optim.Adam(mfSR.parameters(), lr=1e-3)
lossFunction = nn.MSELoss()
for i in range(100):
	optimizer.zero_grad()
	output = mfSR(train_context,train_loc,xmin,xmax)
	loss = lossFunction(output.T[0],train_value)
	print(loss)
	loss.backward()
	optimizer.step()


# ---- Reconstruction query points: 20 times x 64^3 space, inside the training domain ----
recon_step = 20
recon_ngrid = 64
recon_time = torch.linspace(0.6,2.4,recon_step)
recon_axis = torch.linspace(-4.9,4.9,recon_ngrid)
# Follow the configured device; the previous hard-coded .cuda() fails on
# CPU-only installs ("Torch not compiled with CUDA enabled").
recon_point = torch.stack(torch.meshgrid(recon_time,recon_axis,recon_axis,recon_axis)).reshape(4,-1).T.to(device)

# ---- Timed inference: 100 identical forward passes ----
# NOTE(review): the model is still in train() mode here; if meshfreeSR contains
# BatchNorm/Dropout layers, consider calling mfSR.eval() first -- confirm.
temp_time = time.perf_counter()
with torch.no_grad():
	for i in range(100):
		recon_result = mfSR(train_context,recon_point,xmin,xmax)

# Elapsed seconds for all 100 passes (end - start; the previous expression
# printed start - end, i.e. a negative number).
print(time.perf_counter()-temp_time)

# ---- Ground truth: one flattened shell volume per reconstruction time ----
true_grid = torch.stack(torch.meshgrid(recon_axis,recon_axis,recon_axis))
true_result = torch.stack(
	[torch.as_tensor(GaussianRing(true_grid,t,t/2)).reshape(-1) for t in recon_time]
).to(device)


db = {'recon_result':recon_result,'recon_point':recon_point,'train_context':train_context,'train_grid':train_grid,'train_value':train_value,'true_result':true_result}


#torch.save(db,'GaussianRing_5context_3training.torchdb')
