In [1]:
%config IPCompleter.greedy=True

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
import math
import time
import IPython as ip

In [2]:
# Simulation box (x and y ranges) and the per-component speed cap.
x_lim, y_lim, v_lim = [0, 1], [0, 1], 0.5
# `dimension` is the number of components in every position / velocity vector.
n_particle, dimension = 100, 2

# Time step and number of simulated frames.
delta_t, frames = 0.002, 1000

# The same constant external force (0, -9.8) for every particle: a gravity-like pull along -y.
ext_f = torch.tensor([0, -9.8]).repeat(n_particle, 1)

# Support radius of the piecewise-cubic kernel, and its normalisation constants
# for 1D / 2D / 3D (looked up with index `dimension - 1`).
kernel_range = 0.025
kernel_normalization_factor = [
    4 / (3 * kernel_range),
    40 / (7 * math.pi * kernel_range ** 2),
    8 / (math.pi * kernel_range ** 3),
]

# Every particle has unit mass.
mass = torch.ones(n_particle)

def kernel_func(r_diff, h = kernel_range):
    """Evaluate the piecewise-cubic SPH kernel W on displacements and sum over axis 0.

    Args:
        r_diff: displacement vectors with the spatial components on axis 0,
            e.g. shape [2, n_particle, n_particle] with r_diff[:, i, j] = x_i - x_j.
        h: kernel support radius; W is zero where |r| > h.

    Returns:
        W(|r|) evaluated element-wise over ``r_diff.shape[1:]`` and summed over the
        first of those axes (for a [2, n, n] input the result has shape [n]).
    """
    q = torch.norm(r_diff, dim = 0) / h

    # The tabulated constant is for h == kernel_range; the normalisation scales as
    # h ** -dimension, so rescale it for any other h (the factor is exactly 1.0 at the default).
    sigma = kernel_normalization_factor[dimension - 1] * (kernel_range / h) ** dimension

    inner = q <= 0.5
    support = q <= 1

    # Default-dtype accumulator, matching what callers have always received.
    result = torch.zeros(r_diff.shape[1:])
    result += sigma * (6 * (pow(q, 3) - pow(q, 2)) + 1) * inner
    result += sigma * (2 * pow(1 - q, 3)) * ~inner * support

    return result.sum(0)

def kernel_func_grad(r_diff, h = kernel_range):
    """Gradient of the piecewise-cubic SPH kernel W with respect to the displacement.

    Args:
        r_diff: displacement vectors with the spatial components on axis 0,
            e.g. shape [2, n_particle, n_particle] with r_diff[:, i, j] = x_i - x_j.
            The tensor is not modified.
        h: kernel support radius; the gradient is zero where |r| > h.

    Returns:
        Tensor shaped like ``r_diff`` (default float dtype) holding
        grad W(r) = (dW/d|r|) * r / |r| for every displacement vector r.
        Zero displacements (e.g. the i == j diagonal) yield a zero gradient.
    """
    # Radial distance |r|: the Euclidean norm over the spatial axis,
    # not the per-component magnitudes.
    r = torch.norm(r_diff, dim = 0)
    q = r / h
    # Normalisation rescaled for h (identical to the tabulated value at the default h).
    sigma = kernel_normalization_factor[dimension - 1] * (kernel_range / h) ** dimension

    inner = q <= 0.5
    outer = ~inner & (q <= 1)

    # Only the outer branch divides by |r|, and it is only selected where |r| > h/2,
    # so replacing zero distances by 1 here avoids 0/0 -> NaN without affecting results.
    r_safe = torch.where(r > 0, r, torch.ones_like(r))

    # dW/d|r| * 1/|r| for each piece of the kernel.
    inner_coef = -6 * sigma * (2 * h - 3 * r) / pow(h, 3)
    outer_coef = -6 * sigma * pow(h - r, 2) / (pow(h, 3) * r_safe)
    coef = torch.where(inner, inner_coef, torch.where(outer, outer_coef, torch.zeros_like(r)))

    # Accumulate into a default-dtype tensor: callers matmul this result with
    # default-dtype tensors (e.g. `mass`), and matmul does not promote dtypes.
    result = torch.zeros(r_diff.shape)
    result += coef * r_diff   # coef broadcasts across the leading spatial axis

    return result

# Reference density for the state equation: the kernel summed over one zero displacement (its peak value).
rest_density = kernel_func(torch.zeros((2, 1), dtype = float))

def visualize(pos_info, bounds = (-0.1, 1.1, -0.1, 1.1)):
    """Scatter-plot one frame of particle positions.

    Args:
        pos_info: positions of shape [n_particle, 2]; column 0 is x, column 1 is y.
        bounds: (xmin, xmax, ymin, ymax) axis limits; the default is the limits
            this function has always used.
    """
    fig, ax = plt.subplots()

    x, y = pos_info.T
    ax.scatter(x, y)
    ax.axis(list(bounds))
    plt.show()

    # This is called once per frame during playback; close the figure so open
    # figures do not accumulate in pyplot's figure manager.
    plt.close(fig)

In [13]:
class SPHNet(nn.Module):
    """One explicit time step of an SPH particle fluid, packaged as a module.

    The module has no learnable parameters. ``forward`` reads the notebook-level
    settings ``mass``, ``ext_f``, ``delta_t``, ``rest_density``, ``x_lim``,
    ``y_lim``, ``v_lim`` and the kernel helpers ``kernel_func`` / ``kernel_func_grad``.
    """

    def __init__(self):
        super(SPHNet, self).__init__()

    def forward(self, pos, vel):
        """Advance every particle by one step of length ``delta_t``.

        Args:
            pos: positions, shape [n_particle, 2] (column 0 = x, column 1 = y).
            vel: velocities, shape [n_particle, 2].

        Returns:
            (new_pos, new_vel), both [n_particle, 2]; positions are clamped to the
            x_lim / y_lim box and each velocity component to [-v_lim, v_lim].
        """
        # equation 11: density from pairwise displacements, diff_mat[:, i, j] = x_i - x_j
        x, y = pos[:, :1], pos[:, 1:2]
        diff_mat = torch.stack([x - x.transpose(0, -1), y - y.transpose(0, -1)], dim=0)
        # NOTE(review): this is m_i * sum_j W_ij, which equals sum_j m_j W_ij only for
        # uniform masses (true for the current `mass = torch.ones(...)`).
        density = kernel_func(diff_mat) * mass

        # equation 23: intermediate velocity after the external force
        stacked_mass = torch.stack([mass, mass], dim = 1)
        vel_prime = vel + delta_t / stacked_mass * ext_f

        # TODO: viscosity is not implemented yet

        # equation 19: pressure from the density ratio
        pressure = 10 * (pow((density / rest_density), 7) - 1)

        # kernel_grad[:, i, j] = grad_i W(x_i - x_j), shape [2, n, n]
        kernel_grad = kernel_func_grad(diff_mat)

        # Symmetric pressure gradient:
        #   grad p_i = rho_i * sum_j m_j (p_i / rho_i^2 + p_j / rho_j^2) grad W_ij
        # Both terms must sum over the neighbour index j (the last axis).
        term_1 = pressure / density * torch.matmul(kernel_grad, mass)
        term_2 = density * (mass * pressure / pow(density, 2) * kernel_grad).sum(-1)

        # -grad p / rho is already an acceleration, so it is not divided by mass again.
        pressure_acc = (-(term_1 + term_2) / density).transpose(0, -1)

        # update vel and pos
        new_vel = vel_prime + delta_t * pressure_acc
        new_pos = pos + delta_t * new_vel

        # Keep particles inside the box and cap each velocity component.
        new_pos = torch.stack(
            [new_pos[:, 0].clamp(x_lim[0], x_lim[1]), new_pos[:, 1].clamp(y_lim[0], y_lim[1])],
            dim = 1,
        )
        new_vel = new_vel.clamp(-v_lim, v_lim)
        return new_pos, new_vel

In [14]:
# Seed Python's RNG so the random initial layout is reproducible across runs.
SEED = 0
random.seed(SEED)

# Particles start uniformly scattered over the box (x then y drawn per particle), at rest.
pos = torch.tensor(
    [[random.uniform(x_lim[0], x_lim[1]), random.uniform(y_lim[0], y_lim[1])] for _ in range(n_particle)],
    dtype = float,
)
vel = torch.zeros([n_particle, dimension], dtype = float)

# Frame history; index 0 is the initial state.
pos_list, vel_list = [pos], [vel]

sph = SPHNet()

for i in range(frames):
    if i % 100 == 0:
        print(i, 'th frame')
    pos, vel = sph(pos, vel)
    pos_list.append(pos)
    vel_list.append(vel)
    

# Replay the stored frames by redrawing a single output area.
for i, frame_pos in enumerate(pos_list):
    # wait=True defers clearing until the next frame's output arrives, reducing flicker.
    ip.display.clear_output(wait=True)
    visualize(frame_pos)
    time.sleep(delta_t)

KeyboardInterrupt: 