In [1]:
import numpy as np
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.nn.functional as func
import torch.autograd as autograd
import random
import datetime

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Device name:',device)

Device name: cuda:0


In [2]:
##### Define class for neural network #####

eps = torch.finfo(torch.float32).eps

class PINN(nn.Module):
    def __init__(self, input_num, output_num, receivers, senders, radius, bias=True):
        super().__init__()
        
        self.bias=bias
        self.radius_tg = radius
        self.fc1 = nn.Linear(input_num,32, bias=bias)
        self.fc2 = nn.Linear(32,32, bias=bias)
        self.fc3 = nn.Linear(32,32, bias=bias)
        self.fc4 = nn.Linear(32, output_num, bias=bias)
        self.reset_parameters()
        
        self.receivers=receivers
        self.senders=senders

        self.lambda1 = torch.FloatTensor(1).uniform_(-1, 1).requires_grad_(True).to(device)
        self.lambda2 = torch.FloatTensor(1).uniform_(-1, 1).requires_grad_(True).to(device)
        self.lambda3 = torch.FloatTensor(1).uniform_(-1, 1).requires_grad_(True).to(device)
        self.lambda4 = torch.FloatTensor(1).uniform_(-1, 1).requires_grad_(True).to(device)
        self.power1 = torch.FloatTensor(1).uniform_(-1, 1).requires_grad_(True).to(device)
        self.power2 = torch.FloatTensor(1).uniform_(-1, 1).requires_grad_(True).to(device)
        self.power3 = torch.FloatTensor(1).uniform_(-1, 1).requires_grad_(True).to(device)     
        
        self.lambda1 = nn.Parameter(self.lambda1)
        self.lambda2 = nn.Parameter(self.lambda2)
        self.lambda3 = nn.Parameter(self.lambda3)
        self.lambda4 = nn.Parameter(self.lambda4)
        self.power1 = nn.Parameter(self.power1)
        self.power2 = nn.Parameter(self.power2)
        self.power3 = nn.Parameter(self.power3)

        
    def forward(self, x):
        self.input = x
        self.row_num = x.size(0)
        x_max = x[-1].item()//2
        
        x=(self.input-x_max)/x_max
        self.x1=torch.tanh(self.fc1(x))
        self.x2=torch.tanh(self.fc2(self.x1))
        self.x3=torch.tanh(self.fc3(self.x2))
        self.x4 = self.fc4(self.x3)
        output = self.radius_tg*self.x4
        
        return output

    def reset_parameters(self) -> None:
        nn.init.xavier_uniform_(self.fc1.weight, gain = nn.init.calculate_gain('tanh'))
        nn.init.xavier_uniform_(self.fc2.weight, gain = nn.init.calculate_gain('tanh'))
        nn.init.xavier_uniform_(self.fc3.weight, gain = nn.init.calculate_gain('tanh'))
        nn.init.xavier_uniform_(self.fc4.weight, gain=1)
        
        if self.bias:
            nn.init.constant_(self.fc1.bias, 0.1)
            nn.init.constant_(self.fc2.bias, 0.1)
            nn.init.constant_(self.fc3.bias, 0.1)
            nn.init.constant_(self.fc4.bias, 0.1)
            
    def loss_func(self, pred, r_a=1.0, r_b=4.0, v0=2, tau=10):
        loss_mse = nn.MSELoss().to(device)
        #calculate gradients
        
        half = pred.size(1)//2
        radius = (r_a+r_b)/2
        
        center_x = torch.mean(pred[:,0:half], dim=1, keepdim=True)
        center_y = torch.mean(pred[:,half:], dim=1, keepdim=True)
        
        for i in range(pred.size(1)):
            temp=torch.zeros_like(pred)
            temp[:,i]=1
            grads, = autograd.grad(pred, self.input, grad_outputs=temp, create_graph=True)
            if i==0:
                self.u = grads
            else:
                self.u = torch.hstack((self.u, grads))

        self.u_mag=torch.sqrt(torch.square(self.u[:,0:half])+torch.square(self.u[:,half:]))

        for i in range(pred.size(1)):
            temp=torch.zeros_like(pred)
            temp[:,i]=1
            grads, = autograd.grad(self.u, self.input, grad_outputs=temp, create_graph=True)
            if i==0:
                self.accel = grads
            else:
                self.accel = torch.hstack((self.accel, grads))
        
        N_time=pred.size(0)
        
        
        #calculate loss between output edges and f_inter matgnitude from physics

        pos_x = pred[:,0:half]
        pos_y = pred[:,half:]
        pos_x_r = pos_x[:, self.receivers.flatten().type(torch.int64)]
        pos_y_r = pos_y[:, self.receivers.flatten().type(torch.int64)]
        pos_x_s = pos_x[:, self.senders.flatten().type(torch.int64)]
        pos_y_s = pos_y[:, self.senders.flatten().type(torch.int64)]

        dist_x = pos_x_r - pos_x_s
        dist_y = pos_y_r - pos_y_s
        dist=torch.sqrt(torch.square(dist_x)+torch.square(dist_y))
        dist_mat = dist.view(pred.size(0),-1,half-1)
        
        r_mag = torch.sqrt(torch.square(pos_x) + torch.square(pos_y))
        
        
        #### Calculate interaction forces #####
        
        self.f_inter_x=(dist_x/dist)*(self.lambda1*(dist**self.power1)+self.lambda2*(dist**self.power2)+self.lambda3*(dist**self.power3)\
                                      +self.lambda4)
        self.f_inter_y=(dist_y/dist)*(self.lambda1*(dist**self.power1)+self.lambda2*(dist**self.power2)+self.lambda3*(dist**self.power3)\
                                      +self.lambda4)
        

        f_agg_x = torch.zeros(self.row_num, half).to(device)
        f_agg_y = torch.zeros(self.row_num, half).to(device)

        f_agg_x = f_agg_x.scatter_add_(1, self.receivers.flatten().repeat(self.row_num,1).type(torch.int64), self.f_inter_x).to(device)
        f_agg_y = f_agg_y.scatter_add_(1, self.receivers.flatten().repeat(self.row_num,1).type(torch.int64), self.f_inter_y).to(device)
        
        ######## Self propelled term #########        
        
        sp_x = 1*self.u[:,0:half]*(v0/self.u_mag-1)
        sp_y = 1*self.u[:,half:]*(v0/self.u_mag-1)
        
        ode_x = self.accel[:,0:half]-sp_x-f_agg_x
        ode_y = self.accel[:,half:]-sp_y-f_agg_y

        self.loss_force = torch.sum(torch.square(ode_x)+torch.square(ode_y))/(pred.size(0)*half)
        
        #Loss from velocity
        self.loss_vel_mag = loss_mse(self.u_mag[0,:],torch.tensor([v0], dtype=torch.float32).to(device))
        
        #### Set the data point for ground truth ####
        self.data_pt = torch.arange(0, N_time, N_time//(N_time//10))
        
        
        ######## Order parameter #########
        r_mag_avg = torch.mean(r_mag, dim=1, keepdim=True)
        
        denom = radius*v0
        ang_mom = torch.mean(pos_x*self.u[:,half:] - pos_y*self.u[:,0:half], dim=1, keepdim=True)
        ang_mom_abs = torch.mean(torch.abs(pos_x*self.u[:,half:] - pos_y*self.u[:,0:half]), dim=1, keepdim=True)
        order_para =  ang_mom/denom
        order_para_abs = ang_mom_abs/denom
        
        
        self.loss_ord_abs = loss_mse(order_para_abs[self.data_pt,:],\
                                     torch.tensor([1.], dtype=torch.float32).to(device))
        self.loss_radius = loss_mse(r_mag_avg[self.data_pt,:], \
                                    torch.tensor([radius], dtype=torch.float32).to(device))
        
        loss = 1*self.loss_vel_mag + 5*self.loss_radius + 5*self.loss_ord_abs + 1*self.loss_force

                  
        return loss


In [3]:
######### indexing for interaction ############

n_agents = 40

senders=[]
receivers=[]

for i in range(n_agents):
    for j in range(n_agents-1):
        receivers.append([i])
    for k in range(n_agents):
        if k!=i:
            senders.append([k])

senders_G= torch.tensor(senders, dtype=torch.float32).to(device)  #index of the sender node for edge
receivers_G= torch.tensor(receivers, dtype=torch.float32).to(device)  #index of the receiver node for edge


In [4]:
def closure():
    model.train()
    optimizer_LBFGS.zero_grad()
    pred=model(t)
    loss=model.loss_func(pred, r_a=r_a, r_b=r_b, v0=v0, tau=time_end-dt)

    if torch.isfinite(loss).item:
        loss.backward() 
    else:
        pass

    return loss

In [5]:
 
#### define hyperparameter####

dt = 0.05
time_end = 5.0+dt
v0 = 2
epoch_adam=200
epoch_LBFGS=1000
patience=-21
tolerance=1e-4

##define input
t = torch.arange(0, time_end-eps, dt, dtype=torch.float32)
t = t.reshape(len(t),1).to(device)
t.requires_grad=True
out_num = n_agents * 2
    
r_avg = 4
r_a, r_b = r_avg, r_avg

retry_num = 100

print("Training Started")
for ii in range(retry_num):
    loss_value=[]
    ##define model
    model=PINN(input_num=1, output_num=out_num, receivers=receivers_G, senders=senders_G, radius = r_avg).to(device)

    ## ADAM ##
    for i in range(epoch_adam):
        learning_rate = 0.001
        optimizer_adam=torch.optim.Adam(model.parameters(), lr=learning_rate)

        model.train()
        optimizer_adam.zero_grad()

        pred=model(t)
        loss=model.loss_func(pred, r_a=r_a, r_b=r_b, v0=v0, tau=time_end-dt)

        loss.backward()
        optimizer_adam.step()
        
        loss_value.append(loss.item())
        
        ### Early Stop condition ###
        if i>21:
            if loss_value[-1] <= loss_value[-2] and abs(loss_value[-1]-np.mean(loss_value[patience:-1])) < tolerance:
                break




    ##LBFGS ##
    optimizer_LBFGS=torch.optim.LBFGS(model.parameters(), lr=0.1, max_iter=20, line_search_fn = 'strong_wolfe')

    for i in range(epoch_LBFGS):

        loss_prev = optimizer_LBFGS.step(closure)
        if torch.isfinite(loss_prev) == False:
            break
        loss_value.append(loss_prev.item())
        ### Early Stop condition ###
        if i>21:
            if loss_value[-1] <= loss_value[-2] and abs(loss_value[-1]-np.mean(loss_value[patience:-1])) < tolerance:
                break
                
        
    
    ### Stopping Condition ####
    if model.loss_ord_abs.item()<0.1 and model.loss_force.item()<0.1:
        break

print("Training Stopped")        

print("lambda1 = %.3e \nlambda2 = %.3e \nlambda3 = %.3e \nlambda4 = %.3e" \
      %(model.lambda1, model.lambda2, model.lambda3, model.lambda4))
print("power1 = %.3e \npower2 = %.3e \npower3 = %.3e" \
      %(model.power1, model.power2, model.power3))

Training Started


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Training Stopped
lambda1 = -1.358e-01 
lambda2 = 4.574e-01 
lambda3 = 8.822e-02 
lambda4 = -4.157e-01
power1 = -3.428e-01 
power2 = -8.664e-02 
power3 = -2.425e-01
