In [2]:
import os
import sys
from math import pi, sqrt
import numpy as np
from numpy.random import rand
import time
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import grad
from torch.autograd import Variable
from random import random, normalvariate
from sklearn.neighbors import NearestNeighbors

from scipy.interpolate import RegularGridInterpolator

from model.utility import generate_data, histogramcnn, histtopdf,make_grid, cal_grid_data
from model.LLES import kernel

from copy import copy
torch.set_default_dtype(torch.float64)

# Train the smoothing kernel first before training L-LES  model

# Setup utility function

In [2]:
def creat_mesh(n,l):
    x = np.linspace(0,2*np.pi,n+1)
    dx = x[1]-x[0]
    x_pre =x[:l] -x[l]-dx
    x_post = x[:l]+x[-1]+dx
    x = np.concatenate([x_pre,x,x_post],axis=0)
    y=x
    z=x
    return x,y,z
def create_periodicity(u,l):
    n = u.shape[0]
    u2 = np.zeros([n+2*l+1,n+2*l+1,n+2*l+1])
    u2[l:n+l,l:n+l,l:n+l] = u

    u2[l:n+l,l:n+l,-l-1:] = u2[l:n+l,l:n+l,l:l+l+1]
    u2[l:n+l,l:n+l,:l] = u2[l:n+l,l:n+l,-l-l-1:-l-1]

    u2[l:n+l,-l-1:,:] = u2[l:n+l,l:l+l+1,:]
    u2[l:n+l,:l,:] = u2[l:n+l,-l-l-1:-l-1,:]

    u2[-l-1:,:,:] = u2[l:l+l+1,:,:]
    u2[:l,:,:] = u2[-l-l-1:-l-1,:,:]
    return u2
def cal_w(n):
    W = np.zeros(n)
    for i in range(n):
        if i <= int(n/2):
            W[i] = i
        else:
            W[i] = i-n
    return W
def generate_data():
    ngridx = 64
    Wx = cal_w(ngridx)
    Wy = cal_w(ngridx)
    Wz = cal_w(ngridx)[:int(ngridx/2)+1]

    Wx2 = Wx*Wx
    Wy2 = Wy*Wy
    Wz2 = Wz*Wz
    kmax = ngridx/2*(2*np.sqrt(2)/3)
    tke = 1e4
    rhorms = 1e3
    pos = np.random.rand(ngridx**3,3)*2*np.pi
    x,y,z = creat_mesh(ngridx,1)
    
    iu = np.zeros((ngridx, ngridx, int(ngridx/2)+1),dtype = np.complex128)
    iv = np.zeros((ngridx, ngridx, int(ngridx/2)+1),dtype = np.complex128)
    iw = np.zeros((ngridx, ngridx, int(ngridx/2)+1),dtype = np.complex128)
    
    irho = np.zeros((ngridx, ngridx, int(ngridx/2)+1),dtype = np.complex128)

    for i in range(ngridx):
        for j in range(ngridx):
            for k in range(int(ngridx/2)+1):
                W2 = (Wx2[i] + Wy2[j] + Wz2[k])**0.5
                if W2 < kmax and W2 != 0:
                    iu[i,j,k] = tke*complex(np.random.randn(1),np.random.randn(1))
                    iu[i,j,k] = iu[i,j,k]/W2**2
                    iv[i,j,k] = tke*complex(np.random.randn(1),np.random.randn(1))
                    iv[i,j,k] = iv[i,j,k]/W2**2
                    iw[i,j,k] = tke*complex(np.random.randn(1),np.random.randn(1))
                    iw[i,j,k] = iw[i,j,k]/W2**2
                    irho[i,j,k] = rhorms*complex(np.random.randn(1),np.random.randn(1))
                    irho[i,j,k] = irho[i,j,k]/W2**2                    
                    
    u = np.fft.irfftn(iu)
    v = np.fft.irfftn(iu)
    w = np.fft.irfftn(iw)
    rho = np.fft.irfftn(irho)
    
    u = create_periodicity(u,1)
    v = create_periodicity(v,1)
    w = create_periodicity(w,1)
    rho = create_periodicity(rho,1)
    
    interp_u = RegularGridInterpolator((x, y, z), u)
    interp_v = RegularGridInterpolator((x, y, z), v)
    interp_w = RegularGridInterpolator((x, y, z), w)
    interp_rho = RegularGridInterpolator((x, y, z), rho)
    
    par_u = interp_u(pos)
    par_v = interp_v(pos)
    par_w = interp_w(pos)
    par_rho = interp_rho(pos)

    vel = np.stack([par_u, par_v, par_w], axis=1)
    
    pos_next = pos + vel*0.05
    
    par_unext = interp_u(pos_next)
    par_vnext = interp_v(pos_next)
    par_wnext = interp_w(pos_next)
    par_rhonext = interp_rho(pos_next)
    vel_next = np.stack([par_unext, par_vnext, par_wnext], axis=1)
    
    pos_traj = np.stack([pos, pos_next])
    vel_traj = np.stack([vel, vel_next])
    rho_traj = np.stack([par_rho, par_rhonext])
    return torch.from_numpy(pos_traj).to(device), torch.from_numpy(vel_traj).to(device), torch.from_numpy(rho_traj.reshape([2,-1,1])).to(device)

In [3]:
class histogramcnn(nn.Module):
    def __init__(self,maxx):
        super(histogramcnn,self).__init__()
        self.outc = 30
        self.dx = maxx*1.05/self.outc
        self.cnst = torch.tensor(self.outc)
        self.l1 = nn.Conv1d(in_channels=1,out_channels=self.outc*2+1, kernel_size = 1)
        self.act = nn.ReLU()
    def set_param(self,device):
        with torch.no_grad():
            self.l1.weight.fill_(1.0)
            self.l1.bias.data=(-torch.tensor(np.linspace(-self.outc,self.outc,self.outc*2+1,dtype=np.float64)).to(device)*self.dx)
    def forward(self,X):
        out = self.act(X+self.outc*self.dx) - self.outc*self.dx
        out = -self.act(-out+self.outc*self.dx) + self.outc*self.dx
        out = self.l1(X)
        out = out.abs()
        out = self.act(out*(-1.0/self.dx)+1.0)
        out = out.sum(axis=2)
        return out
def histtopdf(hist, data1,data2):
    pdf1 = (hist(data1.reshape([1,1,-1]))+1.0e-10)/data1.shape[0]
    pdf2 = (hist(data2.reshape([1,1,-1]))+1.0e-10)/data2.shape[0]
    kl = (pdf1*(pdf1/pdf2).log()).sum() + (pdf2*(pdf2/pdf1).log()).sum()
    return kl


In [4]:
def make_grid(nfield):
    x = np.linspace(0,2*np.pi,nfield+1)
    y = x
    z = x
    xx,yy,zz = np.meshgrid(x[:nfield],y[:nfield],z[:nfield],indexing='ij')
    grid = np.stack([xx,yy,zz]).transpose([1,2,3,0])
    grid = grid.reshape([-1,3])
    return grid

In [5]:
def cal_grid_data(pos_traj,vel_traj,rho_traj,grid,kernel,neighbor_train):
    print('Generating GT field data')
    n_learn = pos_traj.shape[0]
    pos_tmp = pos_traj.cpu().detach().numpy()
    vel_tmp = vel_traj.cpu().detach().numpy()
    rho_tmp = rho_traj.cpu().detach().numpy()
    h = kernel.h.cpu().detach().numpy()

    traj_gt_period = []
    for i in range(3):
        for j in range(3):
            for k in range(3):
                a = np.ones(pos_tmp.shape)*np.array([2*np.pi*(i-1),2*np.pi*(j-1),2*np.pi*(k-1)])
                b = pos_tmp
                traj_gt_period.append(a+b)
    traj_gt_period=np.stack(traj_gt_period)
    traj_gt_period = traj_gt_period.reshape([-1,3])

    nbrs = NearestNeighbors(n_neighbors=neighbor_train, algorithm='ball_tree').fit(traj_gt_period)

    distances, neighbor_new = nbrs.kneighbors(grid)
    pdistances, pneighbor_new = nbrs.kneighbors(pos_tmp)

    neighbor = (np.remainder(neighbor_new,n_learn))

    rho = np.zeros([n_learn])

    for n in range(neighbor_train):
        dis = torch.tensor(pdistances[:,n]/h)
        with torch.no_grad():
            wxyz = kernel.wnn_nn(dis)
        rho = rho + wxyz.detach().numpy().reshape([-1])

    grid_field = np.zeros([grid.shape[0],4])
    w_field = np.zeros([grid.shape[0],neighbor_train])
    for n in range(neighbor_train):
        dis = torch.tensor(distances[:,n]/h)
        with torch.no_grad():
            wxyz = kernel.wnn_nn(dis)
        w_field[:,n] = wxyz.detach().numpy()/rho[neighbor[:,n]]
        grid_field[:,0] = grid_field[:,0] + rho_tmp[neighbor[:,n],0]*w_field[:,n]
        grid_field[:,1] = grid_field[:,1] + vel_tmp[neighbor[:,n],0]*w_field[:,n]
        grid_field[:,2] = grid_field[:,2] + vel_tmp[neighbor[:,n],1]*w_field[:,n]
        grid_field[:,3] = grid_field[:,3] + vel_tmp[neighbor[:,n],2]*w_field[:,n]


    return grid_field,w_field,rho, neighbor


# Setup LLES and PIML-SK model

In [3]:
class kernel(nn.Module):
    def __init__(self, N,nfield,neighbor_train,device):
        super(kernel, self).__init__()
        self.N = N
        self.D = torch.tensor(3)
        self.device =device
        self.pi = torch.tensor(3.14159265358).to(self.device)

        self.h = nn.Parameter(torch.tensor((((np.pi*2)**3/N*neighbor_train)/np.pi/(4/3))**(1/3)),requires_grad=True)
        self.alpha = nn.Parameter(torch.tensor(100.0),requires_grad=True)

        self.nfield = nfield

        self.lneighbor = neighbor_train
        self.neighbor = torch.zeros([self.N,self.lneighbor],dtype=torch.long).to(self.device)
        self.fneighbor = torch.zeros([self.nfield,self.lneighbor],dtype=torch.long).to(self.device)

        self.qp = torch.linspace(0,1,101).to(self.device)
        self.l1 = nn.Linear(1,20)
        self.l2 = nn.Linear(20,100)
        self.l3 = nn.Linear(100,20)
        self.l4 = nn.Linear(20,1)
        self.act = nn.Tanh()
    def wnn_nn(self,r):
        out = self.l1(r.reshape([-1,1]))
        out = self.act(out)
        out = self.l2(out)
        out = self.act(out)
        out = self.l3(out)
        out = self.act(out)
        out = self.l4(out)
        out = out*torch.sigmoid(10*(1-r)).reshape([-1,1])*self.alpha
        return out.reshape([-1])

    def wnn_r(self,r):
        dr = 0.00001
        x1 = r.reshape([-1,1]) + dr
        x2 = r.reshape([-1,1]) - dr
        y1 = self.wnn_nn(x1)
        y2 = self.wnn_nn(x2)
        return (y1-y2)/(2*dr)/self.h

    def wnn_r_grad(self,r):
        out = vjp(self.wnn_nn, r.reshape([-1,1]), torch.ones(r.shape[0]).to(self.device),create_graph=True )[1]

        return out/self.h
    def wnn_drr(self):
        out = vjp(self.wnn_r_grad,torch.tensor(0.0).reshape([1]).to(self.device), torch.ones(1,1).to(self.device),create_graph=True )[1]

        return out.flatten()/self.h

    def cal_integral(self):
        dh = (self.qp[1]-self.qp[0])*self.h
        y = self.wnn_nn(self.qp).reshape([-1])
        surface = 4.0*self.pi*(self.qp*self.h).pow(2)
        y = y*surface
        return 0.5*dh*(y[0]+y[-1]+2.0*(y[1:-1].sum()))
    
    
    def cal_disv(self,X,Xfield,i,batch):
        temp1 = torch.abs(Xfield[batch]-X[self.fneighbor[batch,i]])
        temp1_1 = -torch.sign(Xfield[batch]-X[self.fneighbor[batch,i]])*torch.sign(Xfield[batch]-X[self.fneighbor[batch,i]]+torch.ones(temp1.shape).to(self.device)*\
self.pi)*torch.sign(Xfield[batch]-X[self.fneighbor[batch,i]]-torch.ones(temp1.shape).to(self.device)*self.pi)
        temp2 = torch.ones(temp1.shape).to(self.device)*self.pi*2.0-temp1
        out = temp1_1*torch.min(torch.stack([temp1,temp2],axis=2),axis=2)[0]
        out2 = torch.sum(out*out,axis=1).reshape([-1,1])
        return torch.sqrt(out2)/self.h, out/torch.sqrt(out2)

    def cal_dis(self,X,i,batch):
        temp1 = torch.unsqueeze(torch.abs(X[batch]-X[self.neighbor[batch,i]]),2)
        temp2 = torch.ones(temp1.shape).to(self.device)*self.pi*2.0-temp1
        out2 = torch.cat((temp1,temp2),axis=2)
        out2 = torch.min(torch.stack([temp1,temp2],axis=2),axis=2)[0]
        return torch.sqrt(torch.sum(out2*out2,axis=1))/self.h


    def cal_dis_field(self,X,Xfield,i,batch):
        temp1 = torch.unsqueeze(torch.abs(Xfield[batch]-X[self.fneighbor[batch,i]]),2)
        temp2 = torch.ones(temp1.shape).to(self.device)*self.pi*2.0-temp1
        out2 = torch.cat((temp1,temp2),axis=2)
        out2 = torch.min(torch.stack([temp1,temp2],axis=2),axis=2)[0]
        return torch.sqrt(torch.sum(out2*out2,axis=1))/self.h
    
    
    def cal_rho_nn(self,X,batch):
        rho = self.wnn_nn(torch.zeros([batch.shape[0]]).to(self.device))
        for i in range(self.lneighbor):
            dis = self.cal_dis(X,i,batch)
            rho = rho+self.wnn_nn(dis)
        return rho

    def cal_rho_nn_field(self,X,Xfield,batch):
        rho = self.wnn_nn(torch.zeros([batch.shape[0]]).to(self.device))
        for i in range(self.lneighbor):
            dis = self.cal_dis_field(X,Xfield,i,batch)
            rho = rho+self.wnn_nn(dis)
        return rho
    
    
    def cal_f_nn(self,X,Xfield,f,ffield,batch):
        rho_f = self.cal_rho_nn_field(X,Xfield,batch).reshape([-1,1])
        rho = torch.zeros([batch.shape[0],f.shape[-1] ]).to(self.device)
        drhodx = torch.zeros([batch.shape[0],self.D ]).to(self.device)

        for i in range(self.lneighbor):
            rho_p = self.cal_rho_nn(X,self.fneighbor[batch,i])
            dis,disv = self.cal_disv(X,Xfield,i,batch)
            rho = rho+f[self.fneighbor[batch,i]]*self.wnn_nn(dis).reshape([-1,1])/rho_p.reshape([-1,1])
            dwdr = self.wnn_r(dis).reshape([-1,1])
            drhodx = drhodx + (f[self.fneighbor[batch,i],1]-ffield[batch]).reshape([-1,1])*disv*dwdr
        return rho, drhodx/rho_f
    def update_neighborlist_sklearn(self,X,Xfield):
        traj_gt_period = []
        traj_copy = X.clone().cpu()
        traj_eu = Xfield.clone().cpu()
        for i in range(3):
            for j in range(3):
                for k in range(3):
                    traj_gt_period.append(traj_copy.cpu()+np.ones(traj_copy.shape)*np.array([2*np.pi*(i-1),2*np.pi*(j-1),2*np.pi*(k-1)]))
        traj_gt_period=np.stack(traj_gt_period)
        traj_gt_period = traj_gt_period.reshape([-1,self.D])
        nbrs = NearestNeighbors(n_neighbors=self.lneighbor, algorithm='ball_tree').fit(traj_gt_period)
        distances, neighbor_new = nbrs.kneighbors(traj_copy)
        neighbor_new= torch.from_numpy(neighbor_new).to(self.device)
        self.neighbor = (torch.remainder(neighbor_new,self.N)).clone().to(self.device)

        distances, neighbor_new = nbrs.kneighbors(traj_eu)
        neighbor_new= torch.from_numpy(neighbor_new).to(self.device)
        self.fneighbor = (torch.remainder(neighbor_new,self.N)).clone().to(self.device)


In [4]:
class LLES(nn.Module):
    def __init__(self,device, N,dt, vref, tref,rhoref, nfield, kernel_wnn ,neighbor_train):
        super(LLES, self).__init__()
        self.N = N  
        self.D = 3
        self.nfeat = 5
        self.nfeat_out_vel  = 2
        self.nfeat_out_rho  = 1

        self.nfield = nfield
        self.kernel = copy(kernel_wnn)
        self.device = device
        self.dt = dt

        self.pi = torch.tensor(3.14159265358).to(self.device)
        self.lneighbor = neighbor_train
        self.neighbor = torch.zeros([self.N,self.lneighbor],dtype=torch.long).to(self.device)
        self.fneighbor = torch.zeros([self.nfield,self.lneighbor],dtype=torch.long).to(self.device)
        self.h=torch.tensor((((np.pi*2)**3/N*neighbor_train)/np.pi/(4/3))**(1/3)).to(device)
        
        self.alpha1 = nn.Parameter(torch.tensor(0.1),requires_grad=True)
        self.alpha2 = nn.Parameter(torch.tensor(0.1),requires_grad=True)

        self.beta1 = nn.Parameter(torch.tensor(0.1),requires_grad=True)
        self.beta2 = nn.Parameter(torch.tensor(0.1),requires_grad=True)

        self.vref = vref
        self.tref = tref
        self.aref = self.vref/self.tref
        self.rhoref = rhoref
        self.drhoref = self.rhoref/self.tref
        
        self.l1 = nn.Linear(self.nfeat, 20)
        self.l2 = nn.Linear(20,100)
        self.l3 = nn.Linear(100,20)
        self.l4 = nn.Linear(20,self.nfeat_out_vel)

        self.l1_rho = nn.Linear(self.nfeat, 20)
        self.l2_rho = nn.Linear(20,100)
        self.l3_rho = nn.Linear(100,20)
        self.l4_rho = nn.Linear(20,self.nfeat_out_rho)

        self.act = nn.Tanh()

    def wnn(self,r):
        with torch.no_grad():
            out = self.kernel.wnn_nn(r.reshape([-1,1]))
        return out

    def knn_nn(self,r):
        out = self.l1(r.reshape([-1,self.nfeat]))
        out = self.act(out)
        out = self.l2(out)
        out = self.act(out)
        out = self.l3(out)
        out = self.act(out)
        out = self.l4(out)

        out_rho = self.l1_rho(r.reshape([-1,self.nfeat]))
        out_rho = self.act(out_rho)
        out_rho = self.l2_rho(out_rho)
        out_rho = self.act(out_rho)
        out_rho = self.l3_rho(out_rho)
        out_rho = self.act(out_rho)
        out_rho = self.l4_rho(out_rho)

        return out, out_rho
    
    def cal_dis(self,X,i,batch):
        temp1 = torch.unsqueeze(torch.abs(X[batch]-X[self.neighbor[batch,i+1]]),2)
        temp2 = torch.ones(temp1.shape).to(self.device)*self.pi*2.0-temp1
        out2 = torch.cat((temp1,temp2),axis=2)
        out2 = torch.min(torch.stack([temp1,temp2],axis=2),axis=2)[0]
        return torch.sqrt(torch.sum(out2*out2,axis=1))/self.kernel.h
    def cal_disv(self,X,V,rho,i,batch):
        temp1 = torch.abs(X[batch]-X[self.neighbor[batch,i+1]])
        temp1_1 = -torch.sign(X[batch]-X[self.neighbor[batch,i+1]])*torch.sign(X[batch]-X[self.neighbor[batch,i+1]]+torch.ones(temp1.shape).to(self.device)*self.pi)\
*torch.sign(X[batch]-X[self.neighbor[batch,i+1]]-torch.ones(temp1.shape).to(self.device)*self.pi)
        temp2 = torch.ones(temp1.shape).to(self.device)*self.pi*2.0-temp1
        out = temp1_1*torch.min(torch.stack([temp1,temp2],axis=2),axis=2)[0]
        outv = (V[batch]-V[self.neighbor[batch,i+1]]).reshape([-1,self.D])
        out = out/self.h
        outv = outv/self.vref
        outc = torch.cross(out,outv,axis=1)
        out2 = torch.sum(out*out,axis=1).reshape([-1,1])
        outv2 = torch.sum(outv*outv,axis=1).reshape([-1,1])
        out2v = torch.sum(out*outv,axis=1).reshape([-1,1])
        drho1 = (rho[batch]).reshape([-1,1])/self.rhoref
        drho2 = (rho[self.neighbor[batch,i+1]]).reshape([-1,1])/self.rhoref

        return torch.cat([drho1,drho2, torch.sqrt(out2),torch.sqrt(outv2),out2v],axis=1), torch.cat([(drho1-drho2)/(drho1-drho2).abs(), out/torch.sqrt(out2),outv/torch.sqrt(outv2)],axis=1)

    def cal_rho_nn(self,X,batch):
        prho = torch.zeros([batch.shape[0],1]).to(self.device)
        for i in range(self.lneighbor):
            dis = self.cal_dis(X,i-1,batch)
            w = self.wnn(dis)
            prho[:] = prho[:] + w
        return prho

    def cal_a_nn(self,X,V,rho,batch):
        drho = torch.zeros([batch.shape[0],1+self.D]).to(self.device)
        for i in range(1,self.lneighbor):
            feature,dis = self.cal_disv(X,V,rho,i-1,batch)
            knn_out, knn_out_rho = self.knn_nn(feature)
            drho[:,0] = drho[:,0] + (self.drhoref)*knn_out_rho[:,0]*dis[:,0]
            drho[:,1:] = drho[:,1:] + self.aref*(knn_out[:,:].reshape([-1,2,1])*dis[:,1:].reshape([-1,2,self.D])).sum(axis=1)
            av = self.cal_av(feature[:,0]-feature[:,1],feature[:,2], feature[:,4],dis[:,1:].reshape([-1,2,self.D])[:,0] )
            drho[:,0] = drho[:,0] + self.drhoref*av[:,0]
            drho[:,1:] = drho[:,1:]  + self.aref*av[:,1:]
        return drho
    def cal_field(self,X,V,rho,batch,w_field):
        field = torch.zeros([batch.shape[0],1+self.D]).to(self.device)
        for i in range(self.lneighbor):
            accl = self.cal_a_nn(X,V,rho,self.fneighbor[batch,i])
            field[:,:] = field[:,:] + w_field[batch,i].reshape([-1,1])*(V[self.fneighbor[batch,i]]+accl*self.dt*0.1)
        return field

    def cal_field_kl(self,X,V,rho,batch,w_field):
        field = torch.zeros([batch.shape[0],1+self.D]).to(self.device)
        accl_t = []
        accl_gt = []

        rho_t = []
        rho_gt = []

        for i in range(self.lneighbor):
            accl = self.cal_a_nn(X,V[0],rho[0],self.fneighbor[batch,i])
            field[:,1:] = field[:,1:] + w_field[batch,i].reshape([-1,1])*(V[0,self.fneighbor[batch,i]]+accl[:,1:]*self.dt)
            field[:,0] = field[:,0] + w_field[batch,i].reshape([-1])*(rho[0,self.fneighbor[batch,i],0]+accl[:,0]*self.dt)

            accl_t.append(accl[:,1:])
            accl_gt.append((V[1,self.fneighbor[batch,i]] - V[0,self.fneighbor[batch,i]])/self.dt)
            rho_t.append(accl[:,0])
            rho_gt.append((rho[1,self.fneighbor[batch,i]] - rho[0,self.fneighbor[batch,i]])/self.dt)
        return field,torch.stack(accl_t).flatten(),torch.stack(accl_gt).flatten(),torch.stack(rho_t).flatten(), torch.stack(rho_gt).flatten()
    def cal_av(self, drho,xx,xv, vec):
        out_rho = drho*self.h**2/(xx**2 + 0.1*self.h**2)
        out_rho = -(torch.abs(self.beta1)+torch.abs(self.beta2)*torch.abs(out_rho))*out_rho
        out_rho = out_rho.reshape([-1,1])

        out = -1.0*self.h*self.act(-1.0*xv)/((xx)**2+0.1*self.h**2)
        out = -1.0*torch.abs(self.alpha1)*out + torch.abs(self.alpha2)*out**2
        out = out.reshape([-1,1])*vec
        return torch.cat([out_rho,out],axis=1)

    def update_neighborlist_sklearn(self,X,fneighbor):
        traj_gt_period = []
        traj_copy = X.clone().cpu()
        for i in range(3):
            for j in range(3):
                for k in range(3):
                    traj_gt_period.append(traj_copy.cpu()+np.ones(traj_copy.shape)*np.array([2*np.pi*(i-1),2*np.pi*(j-1),2*np.pi*(k-1)]))
        traj_gt_period=np.stack(traj_gt_period)
        traj_gt_period = traj_gt_period.reshape([-1,self.D])
        nbrs = NearestNeighbors(n_neighbors=self.lneighbor, algorithm='ball_tree').fit(traj_gt_period)
        distances, neighbor_new = nbrs.kneighbors(traj_copy)
        neighbor_new= torch.from_numpy(neighbor_new).to(self.device)
        self.neighbor = (torch.remainder(neighbor_new,self.N)).clone().to(self.device)

        self.fneighbor = torch.from_numpy(fneighbor).to(self.device)
        return traj_gt_period
    def cal_forcing(self,X,V,batch):
        force = torch.zeros([batch.shape[0],self.D]).to(self.device)
        for i in range(self.lneighbor):
            dis = self.cal_dis(X,i-1,batch)
#            w = self.wnn(dis)                                                     
            force = force + V[self.neighbor[batch,i]]/self.lneighbor
        return force


In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load Lagrangian data

Load the trajectory of the Lagrangian particle to the variable pos_traj(position), vel_traj(velocity), rho_traj(density)

They should have the shape of (ntime, nparticle, 3) for velocity and position or (ntime, nparticle, 1) for density. 

In this notebook, we use ntime =2, nparticle = 262144

In [7]:
#Here we generate random data for demonstration
pos_traj, vel_traj, rho_traj = generate_data(device)

ValueError: too many values to unpack (expected 3)

In [11]:
#Obtain the reference scale oof the velocity and density
vref = (vel_traj.var())**0.5
rhoref = (rho_traj.var())**0.5

In [13]:
N = 262144  ## 64**3
dt = 0.05
tref = torch.tensor(0.1) # Should the be reference timescale at the filtered scale
nfield = 32
nfieldt = nfield**3
neighbor_kernel = 60


In [14]:
#Load the pre-trained kernel function
kernel_wnn = kernel(N, nfield, neighbor_kernel, device)
kernel_wnn.load_state_dict(torch.load("SmoothingKernel.params"))

<All keys matched successfully>

In [15]:
model  = LLES(device, N,dt, vref, tref,rhoref, nfieldt, kernel_wnn ,neighbor_kernel)

In [16]:
traj_train_label_vel = ((vel_traj[1:]-vel_traj[:-1])/dt)
traj_train_label_rho = ((rho_traj[1:]-rho_traj[:-1])/dt)

# Setup Statistics-based Loss function

Build kerenl function that maps samples to histogram.

In [17]:
#Velocity histogram
accl_max = ((vel_traj[1:]-vel_traj[:-1])/dt).abs().max()
datatohist = histogramcnn(accl_max)
datatohist.to(device)
datatohist.set_param(device)

#Density histogram
rho_max = ((rho_traj[1:]-rho_traj[:-1])/dt).abs().max()
datatohist_rho = histogramcnn(rho_max)
datatohist_rho.to(device)
datatohist_rho.set_param(device)

# Setup Eulerian-grid based Loss function

In [18]:
# Generate field data for training
grid = make_grid(nfield)
grid_field_ref, w_field, p_rho, fneighbor=cal_grid_data(pos_traj[0],vel_traj[0],rho_traj[0],grid,kernel_wnn,neighbor_kernel)
field_train_label = torch.tensor(grid_field_ref).to(device)
w_field = torch.from_numpy(w_field).to(device)

Generating GT field data


In [20]:
# Compute the neighbors list of particles and grid points
out = model.update_neighborlist_sklearn(pos_traj[0],fneighbor)

In [40]:
# Configure the coeffcients of different loss function and optimizer
alpha_field = 1.0
alpha_kl = 0.1
alpha_traj = 1.0
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)

In [22]:
# Setup minibatch for field-based and trajectory-based loss function
batch_size_f =32*8
batch_number_f = int(nfieldt/batch_size_f)
batch_array_f = np.arange(nfieldt)
np.random.shuffle(batch_array_f)
batch_array_f=torch.tensor(batch_array_f.reshape([batch_number_f,batch_size_f]))
#-------traj batch ------------                                                
batch_size_t =128*32
batch_number_t = int(N/batch_size_t)
batch_array_t = np.arange(N)
np.random.shuffle(batch_array_t)
batch_array_t=torch.tensor(batch_array_t.reshape([batch_number_t,batch_size_t]))


# First train with Trajectory-based and Statistics-based loss

In [23]:
train_epochs = 100
loss_traj = []

In [34]:
for i in range(train_epochs):
    for nb in range(batch_number_t):
        optimizer.zero_grad()
        accl = model.cal_a_nn(pos_traj[0],vel_traj[0],rho_traj[0],batch_array_t[nb])
    
        # Trajectory-based loss function
        accl_gt = traj_train_label_vel[0,batch_array_t[nb]]
        drho_gt = traj_train_label_rho[0,batch_array_t[nb],0]
        
        traj_l2loss_vel = (accl[:,1:]-accl_gt).pow(2).mean()
        traj_l2loss_rho = (accl[:,0]-drho_gt).pow(2).mean()
        traj_l2loss = traj_l2loss_vel + traj_l2loss_rho
    
        # Statistics-based loss function
        kl = histtopdf(datatohist,accl_gt,accl[:,1:])
        kl_rho = histtopdf(datatohist_rho, drho_gt, accl[:,0])
        kl_loss = kl + kl_rho
        loss_stat = kl_loss

        loss = alpha_traj*traj_l2loss +  alpha_kl*loss_stat

        loss.backward()#retain_graph=True)
        optimizer.step()
        loss_traj.append(0)
        loss_traj.append(traj_l2loss.cpu().detach().numpy())
        loss_traj.append(kl_loss.cpu().detach().numpy())
    print('Epoch = {}, Total loss = {}'.format(i, loss.cpu().detach().numpy()))    
    print('Statistics loss = {}'.format(loss_stat.cpu().detach().numpy()))
    print('Traj loss for velocity= {}'.format(traj_l2loss_vel.cpu().detach().numpy()))
    scheduler.step()

Epoch = 0, Total loss = 0.07674251245070063
Statistics loss = 8.827665312457217
Traj loss for velocity= 0.06454547568574562
Epoch = 1, Total loss = 0.07144572852751603
Statistics loss = 7.720423754694939
Traj loss for velocity= 0.06127346605309656
Epoch = 2, Total loss = 0.06752489586516575
Statistics loss = 6.907704308038942
Traj loss for velocity= 0.0587179037294923
Epoch = 3, Total loss = 0.06453838334653735
Statistics loss = 6.241313788386557
Traj loss for velocity= 0.05665349279660355
Epoch = 4, Total loss = 0.06218550430453057
Statistics loss = 5.9384941274656935
Traj loss for velocity= 0.05493379126362158


KeyboardInterrupt: 

# Then train with Field-based loss

In [39]:
for i in range(train_epochs):
    for nb in range(batch_number_f):
        optimizer.zero_grad()
        pred,accl,accl_gt,drho,drho_gt = model.cal_field_kl(pos_traj[0],vel_traj[:2],rho_traj[:2],batch_array_f[nb],w_field)
        
        l2loss = (pred-field_train_label[batch_array_f[nb]]).pow(2).mean()

        kl = histtopdf(datatohist,accl_gt,accl)
        kl_rho = histtopdf(datatohist_rho, drho_gt, drho)
        kl_loss = kl + kl_rho

        traj_l2loss_vel = (accl-accl_gt).pow(2).mean()
        traj_l2loss_rho = (drho-drho_gt).pow(2).mean()
        traj_l2loss = traj_l2loss_vel #+  traj_l2loss_rho  
        
        loss_stat = kl_loss

        loss = alpha_traj*traj_l2loss +  alpha_kl*loss_stat + alpha_field* l2loss

        loss.backward(retain_graph=True)
        optimizer.step()
        loss_traj.append(0)
        loss_traj.append(traj_l2loss.cpu().detach().numpy())
        loss_traj.append(kl_loss.cpu().detach().numpy())
        
        print('Epoch = {}, Total loss = {}'.format(i, loss.cpu().detach().numpy()))    
        print('Statistics loss = {}'.format(loss_stat.cpu().detach().numpy()))
        print('Traj loss for velocity= {}'.format(traj_l2loss_vel.cpu().detach().numpy()))
        scheduler.step()

Epoch = 0, Total loss = 0.05192473135253242
Statistics loss = 4.967192802207148
Traj loss for velocity= 0.05191997270115471
Epoch = 0, Total loss = 0.050351175703965426
Statistics loss = 4.056871058384869
Traj loss for velocity= 0.05034658165625208
Epoch = 0, Total loss = 0.05577382324099858
Statistics loss = 4.003624417161871
Traj loss for velocity= 0.05576898738437649
Epoch = 0, Total loss = 0.05170432526189021
Statistics loss = 5.356549763414385
Traj loss for velocity= 0.05169962898348966
Epoch = 0, Total loss = 0.05548409118718009
Statistics loss = 4.166721016617365
Traj loss for velocity= 0.05547955624270439
Epoch = 0, Total loss = 0.05538879823986084
Statistics loss = 3.0232281240555423
Traj loss for velocity= 0.055384157120201616
Epoch = 0, Total loss = 0.052030735106188865
Statistics loss = 4.867511873521138
Traj loss for velocity= 0.0520266970353667
Epoch = 0, Total loss = 0.04876524629537842
Statistics loss = 4.005538651189336
Traj loss for velocity= 0.048761037612193275
Epoc

KeyboardInterrupt: 