In [70]:
import numpy as np
import matplotlib.pyplot as plt

from matplotlib import cm

In [140]:
# NOTE(review): the 'notebook' (nbagg) backend only works in the classic Notebook UI;
# in JupyterLab / Notebook 7 use `%matplotlib widget` (requires ipympl) instead.
%matplotlib notebook

In [10]:
def _coeff_matrix_func(x, y, d, a, b):
    return lambda m, n: 4 / (a * b) * np.sum(d[None, None, :] 
                                             * np.sin(np.pi / a * m[:, None, None] * x[None, None, :]) 
                                             * np.sin(np.pi / b * n[None, :, None] * y[None, None, :]),
                                             axis=2)

In [18]:
def _lambda_matrix_func(a, b):
    return lambda m, n: np.pi * np.sqrt((m[:, None] / a)**2 + (n[None, :] / b)**2)

In [50]:
def dissipation_func(x0, y0, d0, a, b, m, n):
    """Build u(x, y, t), a truncated double sine series with exponential decay in time.

        u(x, y, t) = sum_{m, n} C[m, n] * sin(pi m x / a) * sin(pi n y / b) * exp(-lam[m, n]**2 * t)

    where C are the sine coefficients of the weighted point sources (``_coeff_matrix_func``)
    and lam comes from ``_lambda_matrix_func``. Each term satisfies u_t = Laplacian(u) and
    vanishes on the boundary of ``[0, a] x [0, b]``.

    Parameters
    ----------
    x0, y0, d0 : array_like, shape (K,)
        Source positions and weights.
    a, b : float
        Side lengths of the rectangle.
    m, n : array_like of int, shapes (M,) and (N,)
        Mode numbers kept in the truncated series.

    Returns
    -------
    callable
        ``evaluate(x, y, t)`` taking 1-D arrays (or scalars) of lengths X, Y, T and
        returning an ``(X, Y, T)`` array with ``out[i, j, k] = u(x[i], y[j], t[k])``.
    """
    m = np.atleast_1d(m)
    n = np.atleast_1d(n)
    # Mode-dependent quantities do not depend on the evaluation grid: compute once.
    coeff = _coeff_matrix_func(x0, y0, d0, a, b)(m, n)    # (M, N)
    decay_rate = _lambda_matrix_func(a, b)(m, n) ** 2     # (M, N)

    def evaluate(x, y, t):
        # The grid must come from this call's arguments, never from notebook globals.
        x = np.atleast_1d(x)
        y = np.atleast_1d(y)
        t = np.atleast_1d(t)
        sin_x = np.sin(np.pi / a * m[:, None] * x[None, :])         # (M, X)
        sin_y = np.sin(np.pi / b * n[:, None] * y[None, :])         # (N, Y)
        decay = np.exp(-decay_rate[:, :, None] * t[None, None, :])  # (M, N, T)
        # Contract over both mode axes without materialising the full
        # (M, N, X, Y, T) product that a plain broadcast would allocate.
        return np.einsum('mn,mx,ny,mnt->xyt', coeff, sin_x, sin_y, decay, optimize=True)

    return evaluate

In [154]:
# Point sources: positions (x0[k], y0[k]) with weights d0[k].
x0 = np.array([1, 1.2, 2])
y0 = np.array([1.3, 2, 2.3])
d0 = np.array([1, 2, 3])
# Rectangle side lengths: the domain is [0, a] x [0, b].
a, b = 4, 4
# Mode numbers 1..19 kept in the truncated sine series.
m = np.arange(1, 20)
n = np.arange(1, 20)

# Sanity checks: one weight per source, and every source strictly inside the
# domain (on the boundary a source contributes nothing; outside it the series is meaningless).
assert x0.shape == y0.shape == d0.shape
assert np.all((0 < x0) & (x0 < a)) and np.all((0 < y0) & (y0 < b))

In [168]:
# Build the evaluator u(x, y, t) for the sources, domain and modes defined above.
u_func = dissipation_func(x0, y0, d0, a=a, b=b, m=m, n=n)

In [169]:
# Evaluation grid: 25 x 25 points over [0, 4] x [0, 4], and 10 times in [0, 0.1].
x = np.linspace(0.0, 4.0, num=25)
y = x.copy()
t = np.linspace(0.0, 0.1, num=10)

In [170]:
# Evaluate on the whole grid; the result is indexed as u[i, j, k] ~ (x[i], y[j], t[k]).
u = u_func(x=x, y=y, t=t)

In [171]:
# NOTE: meshgrid's default 'xy' indexing yields arrays of shape (len(y), len(x)),
# i.e. transposed relative to u's (x, y) axis order.
x2, y2 = np.meshgrid(x, y, indexing='xy')
u0 = u[..., 0]  # solution at the first time sample, t[0]

In [194]:
time_index = 7  # which entry of t to display
u0 = u[:, :, time_index]

fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
# Source locations, drawn at a height equal to their weights d0.
ax.scatter(x0, y0, d0, marker='o', color='black', linewidths=20)

# 'ij' indexing matches u's (x, y) axis order, so x really lands on the plot's
# x-axis (and non-square grids with len(x) != len(y) work too).
x_mesh, y_mesh = np.meshgrid(x, y, indexing='ij')

# Normalize the colors based on Z value
norm = plt.Normalize(u0.min(), u0.max())
colors = cm.jet(norm(u0))
surf = ax.plot_surface(x_mesh, y_mesh, u0, facecolors=colors, shade=False)
# Fully transparent faces (alpha = 0) so only the mesh edges remain visible.
surf.set_facecolor((0, 0, 0, 0))
ax.set(xlabel='x', ylabel='y', zlabel='u', title=f't = {t[time_index]:.3g}')

<IPython.core.display.Javascript object>

In [173]:
# u must be laid out as (x, y, t); catch grid/array mismatches early.
assert u.shape == (x.size, y.size, t.size)
u.shape

(25, 25, 10)