In [358]:
'''
======================
3D EBM ANIMATION
======================
'''
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from matplotlib.animation import FuncAnimation, PillowWriter  
from mpl_toolkits.mplot3d import Axes3D

import autograd.numpy as np
from autograd import grad

import scipy
from scipy.optimize import fsolve
from  scipy import ndimage

from tqdm.notebook import tqdm

# function that generates the data on the x , y plane
def data_eq(x):    
    a = 0.5
    b = 0
    c = -20
    y = a*x**2 + b*x + c 
    return y


def distance_func(x,P):
    "Distance of a point P from curve"
    _x, _y = P 
    return (x-_x)**2 + (data_eq(x)-_y)**2

dfdx = grad(distance_func)

def get_distance(P):
    """ P = (x,y)
    calculating the shortest distance between an arbitrar point P 
    and data represented by equation data_eq
    # more details check : http://kitchingroup.cheme.cmu.edu/blog/category/optimization/
    # https://math.stackexchange.com/questions/2264702/shortest-distance-from-point-to-curve
    """
    start_x =  P[0]+10 if P[0] > 0 else P[0]-10 
    x = fsolve(dfdx, start_x, args=P)
    y = data_eq(x)
    d = distance_func(x,P)
    return d 


# DATA points
r_data = 6.5  #arbitrary number < r 
X_data = np.arange(-r_data, r_data, 0.2)
Y_data = np.array([data_eq(i) for i in X_data])
# add a bit of noise to the data 
Y_data = np.random.rand(*Y_data.shape)*(Y_data.mean()/10) + Y_data
# Z axis = 0 for all data 
Z_data = np.array([0]*len(X_data))
# but add a bit of noise
Z_data = np.random.rand(*Z_data.shape)*(20) + Z_data


## EBM 
r = 10
X = np.arange(-r, r, 1)
Y = np.arange(-2*r, 2*r, 1)
X, Y = np.meshgrid(X, Y)
R = [get_distance([x,y]) for x,y in zip(X.flatten(),Y.flatten())]
Z = np.array(R).reshape(X.shape)

def enhance_visually(X):
    noise = np.random.rand(*X.shape)
    noise = (X.max()-X.min()) * noise * 0.01    ## scale noise to 0.01 of max value
    X = X + noise
    # smooth Z
    sigma = [2,2]
    X = ndimage.filters.gaussian_filter(X, sigma)
    return X 

Z = enhance_visually(Z)

EBM_updates = 100
CAM_updates = 100

frames = EBM_updates + CAM_updates  ## update ebm + update camera

# Generate each frame
for n in tqdm(range(frames)):
    # first frames EBM is developing
    if n <= EBM_updates-1:
        ebm_u = n
        cam = 0 
    # later camera moving
    else:
        cam = n-EBM_updates
        ebm_u = EBM_updates-1
        
    power = np.arange(0,0.8, step=0.8/EBM_updates)
    mult =  np.arange(1,6,5/EBM_updates)
    azims =  np.arange(-79,-99,-19/CAM_updates)
    elevs =  np.arange(30,46,15/CAM_updates)
    
    plt.style.use('dark_background')
    fig = plt.figure(figsize=(15,15))
    ax = fig.gca(projection='3d')
    ax.axis('off')
    plt.xlim(-11, 11)
    plt.ylim(-21, 21)
    ax.set_zlim(-1,140)

    ax.azim = azims[cam]
    ax.dist = 10
    ax.elev = elevs[cam]
    
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    
    Z_step = Z**power[ebm_u] * mult[ebm_u]
    
    # Plot the EBM
    Z_step_sexy = Z_step - 20
    surf = ax.plot_surface(X, Y, Z_step_sexy, cmap=cm.plasma,
                       linewidth=1,antialiased=True, vmin = -25, vmax = 120,alpha=1, zorder=2, shade=True)

    
    # Plot the data points
    Y_data_sexy = Y_data-10  # looks nicer this way
    Z_data_sexy = Z_data + 51 # make sure the EBM is not above data
    ax.scatter(X_data, Y_data_sexy, Z_data_sexy, marker="o", s=28, c="deeppink",alpha=0.8, zorder=2.5)

    plt.tight_layout()
    plt.savefig(str(n)+'.png', pad_inches=0)
    plt.close()


fwd = list(range(frames))
bwd = list(range(frames))
bwd.reverse()


for n in fwd + bwd:
    exec('a'+str(n)+'=Image.open("'+str(n)+'.png")')
    images.append(eval('a'+str(n)))
images[0].save('EBM.gif',
               save_all=True,
               append_images=images[1:],
               duration=50,
               loop=0)

HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




AttributeError: 'Array' object has no attribute 'save'