# Particle Autoencoder Analysis Example

In [None]:
import torch
import os, math, time
import numpy as np
from sklearn.cluster import KMeans
from matplotlib import pyplot as plt

from train import inference_latent, reconstruction
from process_data import collect_file, data_reader, PointData, vtk_write, numpy_to_vtp, collate_ball
from mean_shift import mean_shift_track
from model.GeoConvNet import GeoConvNet
from torch.utils.data import DataLoader

# Root directory of the datasets; can be overridden with the `data` environment
# variable, otherwise falls back to ./data/.
data_path = os.environ.get('data', './data/')

## Load trained model and prepare the data

In [None]:
# Trained checkpoint to analyse.
load_filename = './example/final_model.pth'
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

# The checkpoint bundles the model weights ('state') with the training
# configuration ('config'). The config is accessed by attribute (args.source,
# args.lat_dim, ...) and is presumably a pickled argparse.Namespace, which
# torch.load's weights_only=True mode (the default from PyTorch 2.6) refuses to
# unpickle. weights_only=False unpickles arbitrary objects: only load trusted
# checkpoints.
try:
    state_dict = torch.load(load_filename, map_location=device, weights_only=False)
except TypeError:
    # PyTorch releases before 1.13 do not accept the weights_only keyword.
    state_dict = torch.load(load_filename, map_location=device)
state = state_dict['state']
# Training-time hyper-parameters; `args` is the alias used by the later cells.
config = state_dict['config']
args = config
print(args)

# Select the particle file to analyse and the number of per-particle columns
# the model was trained on (input_dim appears to count the 3 coordinates plus
# the dataset's attribute columns).
if args.source == "fpm":
    file_list = collect_file(os.path.join(data_path, "2016_scivis_fpm/0.44/run41"), args.source, shuffle=False)
    # Other time steps can be analysed by changing the file name, e.g. 035.vtu.
    fileName = os.path.join(data_path, "2016_scivis_fpm/0.44/run41/025.vtu")
    input_dim = 7
elif args.source == "cos":
    file_list = collect_file(os.path.join(data_path, "ds14_scivis_0128/raw"), args.source, shuffle=False)
    fileName = os.path.join(data_path, 'ds14_scivis_0128/raw/ds14_scivis_0128_e4_dt04_0.4900')
    input_dim = 10
elif args.source == 'jet3b':
    file_list = collect_file(os.path.join(data_path, "jet3b/run3g_50Am_jet3b_sph.3400"), args.source, shuffle=False)
    fileName = os.path.join(data_path, "jet3b/run3g_50Am_jet3b_sph.3400")
    input_dim = 5
else:
    # Fail here with a clear message instead of a NameError on fileName later.
    raise ValueError(f"Unsupported data source in checkpoint config: {args.source!r}")
print(fileName)

model = GeoConvNet(args.lat_dim, input_dim, args.ball, args.enc_out, args.r).float().to(device)
model.load_state_dict(state)
model.eval()
# Analysis only: disable autograd globally for all remaining cells.
torch.set_grad_enabled(False)

data_source = data_reader(fileName, args.source)
# Wrap the particles in PointData; np.arange(len(data_source)) selects every particle.
pd = PointData(data_source, args.k, args.r, args.ball, np.arange(len(data_source)))

kwargs = {'pin_memory': True} if use_cuda else {}
batch_size = 1024
# shuffle=False / drop_last=False keep every particle in its original order, so
# per-batch model outputs line up row-for-row with data_source.
loader = DataLoader(pd, batch_size=batch_size, shuffle=False, drop_last=False,
                    collate_fn=collate_ball if args.ball else None, **kwargs)

## Plot training loss

In [None]:
# Plot the loss log: one float per line, one line per epoch.
with open(os.path.join(args.result_dir, 'epoch_loss_log.txt'), 'r') as f:
    lines = f.readlines()
    # Skip blank lines (e.g. an extra empty line at the end), which float() rejects.
    x = [float(line) for line in lines if line.strip()]
fig = plt.figure(figsize=(10, 6), dpi=100, facecolor='w', edgecolor='k')
plt.plot(np.arange(1, len(x) + 1), x)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

## Calculate the reconstruction PSNR

In [None]:
# Reconstruct every particle with the trained model.
output = reconstruction(model, loader, input_dim, args.ball, device)
# .cpu() is a no-op for CPU tensors and is required before .numpy() if
# reconstruction() returns its result on the GPU.
output = output.cpu().numpy()
# Columns 0-2 of data_source hold the coordinates and are not compared; the
# reconstruction is scored against the remaining attribute columns.
mse = np.mean((output - data_source[:, 3:]) ** 2)
# NOTE(review): 0.64 (= 0.8 ** 2) is used as the squared peak signal value;
# presumably it reflects the attribute normalization range — confirm.
psnr = 10 * math.log10(0.64 / mse)
print('psnr:', psnr)

## Infer latent vectors

In [None]:
# Encode every particle into a latent vector (rows follow the loader's order).
latent = inference_latent(model, loader, args.lat_dim, args.ball, device)
print(latent.shape)
# Persist the latent vectors for reuse; .cpu() makes .numpy() safe even if
# inference_latent() returns a GPU tensor (no-op on CPU).
np.save(os.path.join(args.result_dir, 'latent.npy'), latent.cpu().numpy())

### Cluster the latent vectors and save the result in VTK PolyData (.vtp) format

In [None]:
# Cluster the latent vectors with k-means (8 clusters, 10 restarts).
# `n_jobs` was removed from KMeans in scikit-learn 1.0 (passing it raises
# TypeError); parallelism is now controlled through OpenMP threads, e.g. the
# OMP_NUM_THREADS environment variable.
km = KMeans(n_clusters=8, n_init=10)
# Give scikit-learn a NumPy array; .cpu() is a no-op if latent is already on the CPU.
clu = km.fit_predict(latent.cpu().numpy())
print(clu.shape)

# Write coordinates, predicted and original attributes, and cluster labels to a
# VTK PolyData (.vtp) file.
coord = pd.data[:, :3]
if args.source == 'jet3b':
    data_dict = {
        'pred_rho': output[:, 0],
        'pred_temp': output[:, 1],
        'rho': pd.data[:, 3],
        'temp': pd.data[:, 4],
        'clu': clu,
    }
elif args.source == 'fpm':
    data_dict = {
        'pred_concentration': output[:, 0],
        'pred_velocity': output[:, 1:],
        'concentration': pd.data[:, 3],
        'velocity': pd.data[:, 4:],
        'clu': clu,
    }
elif args.source == 'cos':
    data_dict = {
        "pred_phi": output[:, -1],
        "pred_velocity": output[:, 0:3],
        "pred_acceleration": output[:, 3:6],
        "phi": pd.data[:, -1],
        "velocity": pd.data[:, 3:6],
        "acceleration": pd.data[:, 6:9],
        'clu': clu,
    }
else:
    # Without this, data_dict would be undefined and fail with a NameError below.
    raise ValueError(f"Unsupported data source: {args.source!r}")
vtk_data = numpy_to_vtp(coord, data_dict)
vtk_write(vtk_data, os.path.join(args.result_dir, "predict.vtp"))