In [29]:
import numpy as np
import torch
from cnf.inference_function import decoder
from cnf.utils.normalize import Normalizer_masked
from cnf.nf_networks import SIRENAutodecoder_film
import pyvista as pv

In [146]:
# --- Configuration -----------------------------------------------------------
sim = 0   # index of the simulation to inspect
Nt = 40   # rows per simulation in the stacked latents / data arrays (presumably time steps)
Ng = 20   # NOTE(review): only used via N_samp=Nt*Ng for the normalizer — confirm its meaning

# File paths for the selected case
case = 'NoisyPipe'
training_path = 'training/' + case + '/'
data_dir = 'data/' + case + '/'
data_path = data_dir + 'data.npy'
coords_path = data_dir + 'coords.npy'
coords_o_path = data_dir + 'coords_o.npy'
sdf_path = data_dir + 'sdf.npy'
sdf_o_path = data_dir + 'sdf_o.npy'

# Run on the GPU when one is available
dev = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
print(device)

# Rows belonging to simulation `sim` in the stacked latent / data arrays
sim_rows = slice(sim * Nt, (sim + 1) * Nt)

# --- Latents and model -------------------------------------------------------
# map_location lets a checkpoint saved on the GPU also load on a CPU-only machine
ckpt = torch.load(training_path + 'checkpoint_9999.pt', map_location=device)
latents = ckpt['hidden_states']['latents']
z = latents[sim_rows, :]

# NOTE(review): Nt is passed for two of the size arguments — confirm this matches
# SIRENAutodecoder_film's positional signature rather than being a coincidence.
model = SIRENAutodecoder_film(4, Nt, 4, 15, Nt)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
model.to(device)

# --- Original data -----------------------------------------------------------
# Memory-map each .npy file and copy out only this simulation's slice, so the
# full multi-simulation arrays are never read into RAM.
data = np.array(np.load(data_path, mmap_mode='r')[sim_rows, :, :])
sdf = np.array(np.load(sdf_path, mmap_mode='r')[sim, :, :])
sdf_t = torch.tensor(sdf, device=device).float()
sdf_o = np.array(np.load(sdf_o_path, mmap_mode='r')[sim, :, :])
coords = np.load(coords_path)
coords_t = torch.tensor(coords, device=device).float()
coords_o = np.load(coords_o_path)

# Model input: coordinates with the SDF columns appended along dim 1.
# Named `input` because the decoding cell refers to it (it shadows the builtin).
input = torch.cat([coords_t, sdf_t], dim=1)

# --- Normalizers -------------------------------------------------------------
# Loaded from the same case directory as the checkpoint, so changing `case`
# also switches the normalizer.
norm_params = torch.load(training_path + 'normalizer_params.pt', map_location=device)
x_norm = norm_params['x_normalizer_params']  # currently unused: inputs are not normalized
y_norm = norm_params['y_normalizer_params']
x_normalizer = None
y_normalizer = Normalizer_masked(method='-11', dim=0, params=y_norm,
                                 sdf=sdf_t.reshape((1, -1, 1)), N_samp=Nt*Ng, N_chan=4)


cuda


In [147]:
# Decode the Nt latents of simulation `sim` back to fields on every point.
# Gradients are not needed for inference: disabling autograd saves memory on the
# large (Nt x n_points) output and keeps it safe to convert with `.numpy()` later.
with torch.no_grad():
    out = decoder(input, z, model, x_normalizer, y_normalizer, Nt, device)
out.shape

torch.Size([40, 320000, 4])

In [166]:
# Time step to compare
tidx = 20
assert 0 <= tidx < Nt, f"tidx={tidx} must be in [0, {Nt})"
data_d = out[tidx, :, :].detach().cpu().numpy()  # decoded from the latent
data_o = data[tidx, :, :]                        # original simulation snapshot

# Names for columns 0..3 of data_o / data_d, in order
FLOW_FIELDS = ('u', 'v', 'w', 'p')


def _attach_fields(mesh, fields, sdf):
    """Store columns of ``fields`` on ``mesh`` as u, v, w, p, plus g (= sdf) and m (= sdf > 0)."""
    for col, name in enumerate(FLOW_FIELDS):
        mesh[name] = fields[:, col]
    mesh['g'] = sdf
    mesh['m'] = sdf > 0


# Both meshes share the structured-grid geometry given by coords_o.
# NOTE(review): (200, 40, 40) is hard-coded — confirm it matches coords_o's layout.
mesh_d = pv.StructuredGrid()
mesh_d.points = coords_o
mesh_d.dimensions = (200, 40, 40)
mesh_o = mesh_d.copy()

_attach_fields(mesh_o, data_o, sdf_o)
_attach_fields(mesh_d, data_d, sdf_o)
# Clip the original before scaling: the scale factors below come from its clipped half.
mesh_o = mesh_o.clip('y')

# Scale each flow field of both meshes by the max |value| of the (clipped) original,
# so both panels share one colour scale. The geometry fields g and m are left unscaled.
for name in FLOW_FIELDS:
    scale = np.max(np.abs(mesh_o[name]))
    if scale == 0:
        continue  # identically-zero field: dividing would produce NaNs
    mesh_d[name] = mesh_d[name] / scale
    mesh_o[name] = mesh_o[name] / scale
mesh_d = mesh_d.clip('y')

# Pressure colour range over non-zero points
# (zeros presumably mark masked points — confirm against the data)
p_nonzero = mesh_o['p'][mesh_o['p'] != 0]
pmin = np.min(p_nonzero)
pmax = np.max(p_nonzero)

# Field to display and the colour limits per field
scalar = 'u'
clim_dict = {
    'u': (-1, 1),
    'v': (-1, 1),
    'w': (-1, 1),
    'p': (pmin, pmax),
    'g': (-1, 1),
    'm': (0, 1),
}

# Side-by-side linked views: original (left) vs decoded (right)
pl = pv.Plotter(shape=(1, 2))
for col, (mesh, title) in enumerate([(mesh_o, 'Original'), (mesh_d, 'Decoded')]):
    pl.subplot(0, col)
    pl.add_mesh(mesh, scalars=scalar, clim=clim_dict[scalar])
    pl.camera_position = 'zx'
    pl.show_axes()
    pl.add_title(title)
pl.link_views()
pl.show()


Widget(value='<iframe src="http://localhost:33601/index.html?ui=P_0x7317efea2990_106&reconnect=auto" class="py…