In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

project_path = "./examples/elastic/Iso-elatic-Marmousi2/data"

# np.load on an .npz archive returns an NpzFile that keeps the underlying file
# open until it is closed. Use it as a context manager so the handle is released
# as soon as the "data" array has been read into memory.
# The [::3] stride keeps every third stored snapshot along the first axis.
with np.load(os.path.join(project_path, "inversion/iter_vp.npz")) as vp_archive:
    iter_vp = vp_archive["data"][::3]
with np.load(os.path.join(project_path, "inversion/iter_vs.npz")) as vs_archive:
    iter_vs = vs_archive["data"][::3]

# Density is derived from vp rather than loaded from disk.
# NOTE(review): 310 * vp**0.25 looks like Gardner's relation (vp in m/s, rho in
# kg/m^3) -- confirm this matches the relation used by the inversion itself.
iter_rho = np.power(iter_vp, 0.25) * 310

In [None]:
###########################################
# visualize the inversion results
###########################################
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import os

# Build a 1x3 figure, one panel per model parameter (wide figsize for spacing).
fig, ax = plt.subplots(1, 3, figsize=(15, 6))

# (snapshots, panel title, colorbar label) for each panel, left to right.
panels = (
    (iter_vp, 'P-wave Velocity', 'Velocity (m/s)'),
    (iter_vs, 'S-wave Velocity', 'Velocity (m/s)'),
    (iter_rho, 'Density', 'Density (kg/m³)'),
)

# Draw the first snapshot of every parameter. vmin/vmax are taken over all
# snapshots so the colour scale stays fixed for the whole animation.
cax1, cax2, cax3 = (
    axis.imshow(frames[0], aspect='equal', cmap='jet_r', vmin=frames.min(), vmax=frames.max())
    for axis, (frames, _, _) in zip(ax, panels)
)

# Attach a labelled horizontal colorbar underneath each panel.
colorbars = []
for image, axis, (_, _, label) in zip((cax1, cax2, cax3), ax, panels):
    cbar = fig.colorbar(image, ax=axis, orientation='horizontal', fraction=0.046, pad=0.1, shrink=0.8)
    cbar.set_label(label, fontsize=10)
    colorbars.append(cbar)
cbar1, cbar2, cbar3 = colorbars

# Title every panel.
for axis, (_, title, _) in zip(ax, panels):
    axis.set_title(title, fontsize=12)

# Fix the subplot margins so the panels and their colorbars sit centred.
fig.subplots_adjust(top=0.85, bottom=0.35, left=0.1, right=0.9)

# Initialization callback for the animation
def init():
    """Reset all three images to the first snapshot.

    Returns:
        tuple: ``(cax1, cax2, cax3)``, the image artists that were updated.
    """
    images = (cax1, cax2, cax3)
    for image, frames in zip(images, (iter_vp, iter_vs, iter_rho)):
        image.set_array(frames[0])
    return images

# Per-frame callback for the animation
def animate(i):
    """Show snapshot ``i`` of each model in its panel.

    Args:
        i: Index into the first axis of ``iter_vp``, ``iter_vs`` and ``iter_rho``.

    Returns:
        tuple: ``(cax1, cax2, cax3)``, the image artists that were updated.
    """
    images = (cax1, cax2, cax3)
    for image, frames in zip(images, (iter_vp, iter_vs, iter_rho)):
        image.set_array(frames[i])
    return images

# Build the animation: one frame per loaded snapshot, 200 ms between frames.
ani = animation.FuncAnimation(
    fig,
    animate,
    init_func=init,
    frames=len(iter_vp),
    interval=200,
    blit=True,
)

# Write a GIF copy alongside the inversion outputs using the Pillow writer.
# NOTE(review): interval=200 ms corresponds to 5 fps, but the GIF is saved with
# fps=10, so the file plays faster than the inline player -- confirm intended.
gif_path = os.path.join(project_path, "inversion/inversion_process.gif")
ani.save(gif_path, writer='pillow', fps=10)

# Close the figure so the notebook does not also render a static image of it,
# then embed the interactive JavaScript player as this cell's output.
plt.close(fig)
HTML(ani.to_jshtml())