In [1]:
import matplotlib as mpl
import numpy as np
import matplotlib.pyplot as plt
#%matplotlib inline

import pickle
from models.mesh_model import MeshModel, Abundance, rotation, mesh_model, pulsation, model_spectrum, abundance_spot
from models.mesh_generation import apply_spherical_harm_pulsation, mesh_polar_vertices, vertex_to_polar
import jax.numpy as jnp
import jax
from celluloid import Camera
# Wavelength grid for the spectra: 10 000 points evenly spaced in log10(wavelength)
# between 5355 and 5375 Å (plots convert back with 10**LOG_WAVELENGTHS).
LOG_WAVELENGTHS = jnp.linspace(start=jnp.log10(5355), stop=jnp.log10(5375), num=10000)

Models defined.




In [2]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [3]:
# Baseline abundance vector: 17 entries, all equal to 1.0 (scaled where it is used).
abundances = np.full(17, 1.0)

In [4]:
# Effective-temperature bounds in log10 space (presumably Kelvin — not used in the cells shown here).
MIN_TEFF, MAX_TEFF = np.log10(2500.0), np.log10(7999.9)
# Ten evenly spaced timestamps on [0, 1], passed as `timestamps` when building models (e.g. m_Fe).
TIMESTAMPS = jnp.linspace(0.0, 1.0, num=10)

In [5]:
# Reference model with uniform (low) abundances. This cell must run: later cells
# (low- vs high-Fe spectra, the mu map, the rotation maps) all reference `m`.
m = mesh_model(n_vertices=4000,
               teff=0.9,
               logg=0.8, vturb=1.,
               abundances=abundances*0.1,
               timestamps=jnp.linspace(0., 10., 50))
# Maybe there is no need to worry about timestamps and we could work with phases instead.

In [None]:
# Fe-enhanced model: the first 16 abundance entries are 0.1 and the last one
# (index 16, read back later as the Fe column) is 0.5.
abundances_Fe = jnp.concatenate([jnp.full(16, 0.1), jnp.array([0.5])])
m_Fe = mesh_model(
    n_vertices=1000,
    teff=0.9,
    logg=0.8,
    vturb=1.,
    abundances=abundances_Fe,
    timestamps=TIMESTAMPS,
)

In [None]:
# Spectra at time index 0: `s` from the uniform-abundance model `m`,
# `s_Fe` from the Fe-enhanced model `m_Fe`.
s, s_Fe = [model_spectrum(model, 0, LOG_WAVELENGTHS) for model in (m, m_Fe)]

In [None]:
# Fe-enhanced spectrum at time index 0 over a second window, 5440-5470 Å.
LOG_WAVELENGTHS2 = jnp.linspace(jnp.log10(5440), jnp.log10(5470), 10000)
s_Fe2 = model_spectrum(m_Fe, 0, LOG_WAVELENGTHS2)
# s[0]/s[1] is what the y-label calls normalized flux (presumably flux / continuum — TODO confirm).
plt.plot(jnp.power(10, LOG_WAVELENGTHS2), s_Fe2[0]/s_Fe2[1], color='black')
# Raw string: '\A' is an invalid escape in a normal literal (SyntaxWarning on Python 3.12+).
plt.gca().set_xlabel(r'Wavelength [$\AA$]');
plt.gca().set_ylabel('Normalized flux');

In [None]:
# Uniform-abundance (`m`, "Low Fe") vs Fe-enhanced (`m_Fe`, "High Fe") spectra at time index 0.
plt.plot(jnp.power(10, LOG_WAVELENGTHS), s[0]/s[1],
         color='black', label='Low Fe')
plt.plot(jnp.power(10, LOG_WAVELENGTHS), s_Fe[0]/s_Fe[1],
         color='red', alpha=0.5, label='High Fe')
# Raw string: '\A' is an invalid escape in a normal literal (SyntaxWarning on Python 3.12+).
plt.gca().set_xlabel(r'Wavelength [$\AA$]');
plt.gca().set_ylabel('Normalized flux');
plt.legend();

In [None]:
# Apply rotation to the Fe-enhanced model; the comparison plot below labels this "50 km/s".
m_Fe_Rot = rotation(m_Fe,
                    rotation_velocity=50)

In [None]:
# Spectrum of the rotating model at time index 0, on the same grid as `s_Fe`.
s_Fe_Rot = model_spectrum(
    m_Fe_Rot, 0, LOG_WAVELENGTHS
)

In [None]:
# Effect of rotation on the Fe-enhanced spectrum: no rotation vs rotation_velocity=50.
# Inline plotting is already enabled by the `%matplotlib inline` cell near the top,
# so the magic is not repeated here.
plt.plot(jnp.power(10, LOG_WAVELENGTHS), s_Fe[0]/s_Fe[1],
         color='black', label='no rotation')
plt.plot(jnp.power(10, LOG_WAVELENGTHS), s_Fe_Rot[0]/s_Fe_Rot[1],
         color='red', label='50 km/s')
plt.legend();
# Raw string: '\A' is an invalid escape in a normal literal (SyntaxWarning on Python 3.12+).
plt.gca().set_xlabel(r'Wavelength [$\AA$]')
plt.gca().set_ylabel('Normalized flux');

In [None]:
# Faces of `m` with mu > 0 (presumably those facing the observer) at the first
# timestamp, projected on the X-Z plane (centre columns 0 and 2) and coloured by mu.
time_index = 0
fig, ax = plt.subplots(figsize=(6, 5))
mu_mask = m.mus[time_index] > 0
s1 = ax.scatter(m.centers[time_index, mu_mask, 0],
                m.centers[time_index, mu_mask, 2],
                s=2., c=m.mus[time_index, mu_mask],
                cmap='magma')
cbar = fig.colorbar(s1)
# Raw strings: '\m' and '\o' are invalid escapes in normal literals (SyntaxWarning on 3.12+).
cbar.set_label(r'$\mu$')
ax.set_xlabel(r'X [$R_\odot$]')
ax.set_ylabel(r'Z [$R_\odot$]');

# Rotation

In [None]:
# Line-of-sight velocity map of `m` after rotation with rotation_velocity=10.
m10 = rotation(m, rotation_velocity = 10.)

time_index = 0
fig, ax = plt.subplots(figsize=(6, 5))
# Take the mask from the model being plotted so mask and points always correspond.
mu_mask = m10.mus[time_index] > 0
s1 = ax.scatter(m10.centers[time_index, mu_mask, 0], m10.centers[time_index, mu_mask, 2],
                s=2., c=m10.los_velocities[time_index, mu_mask], cmap='turbo')
cbar = fig.colorbar(s1)
cbar.set_label('LOS velocity [km/s]')
# Raw strings: '\o' is an invalid escape in a normal literal (SyntaxWarning on 3.12+).
ax.set_xlabel(r'X [$R_\odot$]')
ax.set_ylabel(r'Z [$R_\odot$]');

In [None]:
# Same LOS-velocity map, with inclination=[1, 0, 0], shown on the X-Y plane.
m10_i = rotation(m, rotation_velocity = 10., inclination=jnp.array([1., 0., 0.]))

time_index = 0
fig, ax = plt.subplots(figsize=(6, 5))
# Take the mask from the model being plotted so mask and points always correspond.
mu_mask = m10_i.mus[time_index] > 0
s1 = ax.scatter(m10_i.centers[time_index, mu_mask, 0],
                m10_i.centers[time_index, mu_mask, 1],
                s=2., c=m10_i.los_velocities[time_index, mu_mask], cmap='turbo')
cbar = fig.colorbar(s1)
cbar.set_label('LOS velocity [km/s]')
# Raw strings: '\o' is an invalid escape in a normal literal (SyntaxWarning on 3.12+).
ax.set_xlabel(r'X [$R_\odot$]')
# Centre column 1 is plotted here, i.e. Y (columns 0 and 2 are labelled X and Z elsewhere).
ax.set_ylabel(r'Y [$R_\odot$]');

# Pulsations

In [None]:
from models.mesh_model import vec_apply_spherical_harm_pulsation
def __pulsation(mesh_model: MeshModel,
                m: float, n: float,
                magnitude: float,
                t0: float, period: float) -> MeshModel:
    """Local experimental pulsation: perturb a mesh model with a sinusoidal mode.

    The notebook later switches to the packaged ``pulsation`` function; this
    version is kept for comparison.

    Args:
        mesh_model: model to perturb. It is not modified; a new MeshModel is returned.
        m, n: mode indices forwarded to ``vec_apply_spherical_harm_pulsation``.
        magnitude: amplitude of both the displacement and the velocity term.
        t0: reference time at which the pulsation angle is zero.
        period: pulsation period, in the same units as ``mesh_model.timestamps``.

    Returns:
        A new MeshModel whose vertices, centers and areas are shifted by the
        offsets from ``vec_apply_spherical_harm_pulsation``, with updated
        ``velocities`` and ``los_velocities``. All other fields are copied unchanged.
    """
    # (t - t0) / period is elapsed time in cycles; sin/cos take radians, so scale
    # by 2*pi to make the mode repeat exactly once per `period`.
    angles = 2 * jnp.pi * (mesh_model.timestamps[:, jnp.newaxis] - t0) / period
    amplifications = magnitude * jnp.sin(angles)
    # NOTE(review): d/dt[magnitude*sin(angle)] = magnitude*(2*pi/period)*cos(angle);
    # the 2*pi/period factor is not applied here (as before) — confirm intent.
    velocities = magnitude * jnp.cos(angles)
    vert_offsets, center_offsets, area_offsets, sph_ham = vec_apply_spherical_harm_pulsation(
        mesh_model.vertices, mesh_model.centers,
        mesh_model.faces,
        mesh_model.areas,
        amplifications, m, n)
    # NOTE(review): the broadcasting below assumes `velocities` is (n_timestamps, 1)
    # and that `sph_ham` broadcasts against `centers` — shapes are not visible here; confirm.
    puls_velocities = mesh_model.velocities + mesh_model.centers*velocities[:, jnp.newaxis]*sph_ham
    return MeshModel(timestamps=mesh_model.timestamps,
                     los_vector=mesh_model.los_vector,
                     radius=mesh_model.radius,
                     mass=mesh_model.mass,
                     teffs=mesh_model.teffs,
                     logg=mesh_model.logg,
                     vturb=mesh_model.vturb,
                     abundances=mesh_model.abundances,
                     vertices=mesh_model.vertices+vert_offsets,
                     centers=mesh_model.centers+center_offsets,
                     faces=mesh_model.faces,
                     areas=mesh_model.areas+area_offsets,
                     mus=mesh_model.mus,
                     velocities=puls_velocities,
                     los_velocities=mesh_model.los_velocities+sph_ham*velocities[:, jnp.newaxis, :]*mesh_model.mus[:, :, jnp.newaxis])

In [None]:
# Try the local implementation on the Fe-enhanced model. Note that `mp` is
# reassigned by the packaged `pulsation` call a couple of cells below.
mp = __pulsation(m_Fe, m=1, n=1, magnitude=1., t0=0., period=1.)

In [None]:
# Quick check of the value range of the uniform-abundance spectrum `s`.
(jnp.min(s), jnp.max(s))

In [None]:
# Packaged pulsation model: same positional arguments as the __pulsation call above
# except the last one (0.1 instead of 1.0). The argument order is presumably
# (m, n, magnitude, t0, period) as in __pulsation — confirm.
# This replaces the `mp` built by __pulsation.
mp = pulsation(
    m_Fe, 1, 1, 1., 0., .1
)

In [None]:
# Animate the pulsating model: one 3D scatter per timestamp, coloured by the
# projection of each face's velocity on the line-of-sight vector.
fig = plt.figure(figsize=(10, 9))
spec = fig.add_gridspec(7, 9)
ax = fig.add_subplot(spec[:5, :8], projection='3d')
cbar_ax = fig.add_subplot(spec[1:4, 8])
spectr_ax = fig.add_subplot(spec[5:, 1:])
# A single camera, created once all axes exist (a second Camera(fig) was redundant).
camera = Camera(fig)
axes_lim = 1.5*mp.radius
ax.set_xlim3d(-axes_lim, axes_lim)
ax.set_ylim3d(-axes_lim, axes_lim)
ax.set_zlim3d(-axes_lim, axes_lim)
cmap = 'magma'
# plt.get_cmap replaces mpl.cm.get_cmap, which was removed in Matplotlib 3.9.
cool_cmap = plt.get_cmap('coolwarm')

dot_velocities = jnp.dot(mp.velocities, mp.los_vector)

# One colour scale shared by every frame.
norm = mpl.colors.Normalize(vmin=jnp.min(dot_velocities),
                            vmax=jnp.max(dot_velocities))
for phase_index in range(len(mp.centers)):
    centers = mp.centers[phase_index]

    p = ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2],
                   cmap=cmap, norm=norm,
                   c=dot_velocities[phase_index])
    # Red arrow pointing along the line-of-sight vector.
    ax.quiver(*(-3*mp.los_vector), *mp.los_vector, color='red')
    # Colour bar is redrawn every frame, presumably so each celluloid snapshot contains it.
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax)
    camera.snap()
    

In [None]:
# Turn the snapped frames into an animation and write it to puls.gif.
animation = camera.animate()
animation.save(filename='puls.gif')

In [None]:
# Narrow window (5362.5-5363.25 Å): static Fe-enhanced model vs pulsating model `mp`, time index 0.
LOG_WAVELENGTHS_LINE = jnp.linspace(jnp.log10(5362.5),
                                    jnp.log10(5363.25), 10000)
sl = model_spectrum(m_Fe, 0, LOG_WAVELENGTHS_LINE)
sp = model_spectrum(mp, 0, LOG_WAVELENGTHS_LINE)
plt.plot(jnp.power(10, LOG_WAVELENGTHS_LINE), sl[0]/sl[1],
         color='black', label='no pulsation')
plt.plot(jnp.power(10, LOG_WAVELENGTHS_LINE), sp[0]/sp[1],
         color='red', label='pulsation')
plt.legend();
# Raw string: '\A' is an invalid escape in a normal literal (SyntaxWarning on Python 3.12+).
plt.gca().set_xlabel(r'Wavelength [$\AA$]')
plt.gca().set_ylabel('Normalized flux');

# Spots

In [None]:
# Fe spot on the Fe-enhanced model, then rotation (rotation_velocity=10, inclination +z).
# `Abundance.Fe.value - 1` is presumably the 0-based Fe column index — confirm.
mt = rotation(
    abundance_spot(m_Fe, 1.5, 1.5, .1, .5, 0.75, Abundance.Fe.value - 1),
    rotation_velocity=10.,
    inclination=jnp.array([0., 0., 1.]),
)

In [None]:
# Largest value in abundance column 16 (plotted as Fe later) after the spot is applied.
jnp.max(
    mt.abundances[:, 16]
)

In [None]:
# Inspect the line-of-sight velocities at the first timestamp.
mt.los_velocities[0, ...]

In [None]:
# Diverging colormap used to colour spectra by phase.
# plt.get_cmap replaces mpl.cm.get_cmap, which was removed in Matplotlib 3.9.
cool_cmap = plt.get_cmap('coolwarm')

In [None]:
# Show the colormap swatch.
(cool_cmap)

In [None]:
# Line profile at every timestamp of an Fe-spotted, slowly rotating model,
# coloured along the coolwarm colormap by phase index.
mt = rotation(abundance_spot(m_Fe, 0.5, 1., .1, 1., 0.75, Abundance.Fe.value-1),
              rotation_velocity=5.,
              inclination=jnp.array([0., 0., 1.]))

ZOOMED_IN_LINE = np.linspace(np.log10(5361.), np.log10(5363.5), 1000)
# Iterate over the timestamps the model actually has. The previous fixed 50
# indices went past the end (m_Fe is built from TIMESTAMPS, 10 entries).
n_phases = len(mt.centers)
fig, ax = plt.subplots(1, 1, figsize=(9, 5))
for phase_ind, cmap_i in enumerate(np.linspace(0, 1, n_phases)):
    atmo = model_spectrum(mt, phase_ind, ZOOMED_IN_LINE)
    ax.plot(np.power(10, ZOOMED_IN_LINE), atmo[0]/atmo[1], color=cool_cmap(cmap_i), alpha=0.5)
# Raw string: '\A' is an invalid escape in a normal literal (SyntaxWarning on Python 3.12+).
ax.set_xlabel(r'Wavelength [$\AA$]')
ax.set_ylabel('Normalized flux');

In [None]:
# Line profile at each timestamp of the Fe-spotted, slowly rotating model.
mt = rotation(abundance_spot(m_Fe, 0.5, 1., .1, 1., 0.75, Abundance.Fe.value-1),
              rotation_velocity=5.,
              inclination=jnp.array([0., 0., 1.]))

ZOOMED_IN_LINE = np.linspace(np.log10(5361.), np.log10(5363.5), 1000)
# The model has len(mt.centers) timestamps (10 via TIMESTAMPS). The previous
# fixed 12 indices requested phases that do not exist.
n_phases = len(mt.centers)
fig, ax = plt.subplots(1, 1, figsize=(9, 5))
for phase_ind, cmap_i in zip(range(n_phases), np.linspace(0, 1., n_phases)):
    atmo = model_spectrum(mt, phase_ind, ZOOMED_IN_LINE)
    ax.plot(np.power(10, ZOOMED_IN_LINE), atmo[0]/atmo[1], color=cool_cmap(cmap_i), alpha=0.5)
# Raw string: '\A' is an invalid escape in a normal literal (SyntaxWarning on Python 3.12+).
ax.set_xlabel(r'Wavelength [$\AA$]')
ax.set_ylabel('Normalized flux');

In [None]:
# Four snapshots (timestamp indices 4, 5, 7, 8) of a rotating model with an Fe spot
# (top row, X-Z projection coloured by the Fe column) and the matching line
# profiles (bottom panel, coloured by snapshot index).
# NOTE(review): other cells call rotation(..., inclination=...); confirm that
# `rotation_axis=` is an accepted keyword here.
mt = rotation(abundance_spot(m_Fe, 1.5, 1.5, .5, .5, 0.75, Abundance.Fe.value-1),
              rotation_velocity=5.,
              rotation_axis=jnp.array([0., 0., 1.]))
ZOOMED_IN_LINE = np.linspace(np.log10(5362.5), np.log10(5363.25), 1000)
fig = plt.figure(figsize=(12, 5))
spec = fig.add_gridspec(10, 13)

# Top row: four map panels (columns 0-3, 3-6, 6-9, 9-12) sharing axes with the first.
axes = [fig.add_subplot(spec[:5, :3])]
ax1 = axes[0]
for col in (3, 6, 9):
    axes.append(fig.add_subplot(spec[:5, col:col + 3], sharex=ax1, sharey=ax1))
cbar_ax = fig.add_subplot(spec[:5, -1])
spectr_ax = fig.add_subplot(spec[6:, 1:-2])
spectr_cbar_ax = fig.add_subplot(spec[6:, -1])
# plt.get_cmap replaces mpl.cm.get_cmap, which was removed in Matplotlib 3.9.
cool_cmap = plt.get_cmap('coolwarm')

# Colour scale of the Fe column (index 16) on the maps.
norm = mpl.colors.Normalize(vmin=jnp.min(mt.abundances[:, 16]),
                            vmax=jnp.max(mt.abundances[:, 16]))
# Maps the plotted timestamp indices (4..8) onto [0, 1] for the spectrum colours.
norm_s = mpl.colors.Normalize(vmin=4,
                              vmax=8)

for ax_i, i in enumerate([4, 5, 7, 8]):
    mus = mt.mus[i]
    mu_mask = mus > 0
    im = axes[ax_i].scatter(mt.centers[i, mu_mask, 0], mt.centers[i, mu_mask, 2],
                            c=mt.abundances[mu_mask, 16], norm=norm, cmap='magma', s=1)
    atmo = model_spectrum(mt, i, ZOOMED_IN_LINE)
    spectr_ax.plot(np.power(10, ZOOMED_IN_LINE), atmo[0]/atmo[1], color=cool_cmap(norm_s(i)))
    if ax_i != 0:
        axes[ax_i].set_yticklabels([])

# The maps plot centre column 2, which is Z (as in the X-Z maps earlier in the notebook).
# Raw strings: '\o', '\A' and '\p' are invalid escapes in normal literals (SyntaxWarning on 3.12+).
axes[0].set_ylabel(r'Z [$R_\odot$]')
spectr_ax.set_xlabel(r'Wavelength [$\AA$]')
spectr_ax.set_ylabel('Normalized flux')
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.set_label('Fe abundance (normalized)')
# Stand-alone mappable on [0, 1] for the spectrum colour bar. This replaces an
# empty LineCollection that was only used to obtain a colorbar.
s_cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=mpl.colors.Normalize(0., 1.), cmap=cool_cmap),
                      cax=spectr_cbar_ax)
s_cbar.set_label('Phase')
s_cbar.set_ticks([0., 0.5, 1.])
s_cbar.set_ticklabels(['0.0', r'$\frac{\pi}{2}$', r'$\pi$'])

In [None]:
# Save the Fe-spot figure built in the previous cell.
fig.savefig(fname='fe_spot.png')

In [None]:
# 3D animation of the Fe-spotted rotating model, coloured by the Fe column
# (index 16), with the spectrum at each timestamp in the lower panel.
from models.mesh_model import teff_spot
mt = rotation(abundance_spot(m_Fe, 0.5, 1., .1, 1., 0.75, Abundance.Fe.value-1),
              rotation_velocity=5.,
              inclination=jnp.array([0., 0., 1.]))

# Wavelength window for the spectrum panel. Defined here so this cell does not
# depend on whichever earlier cell last assigned ZOOMED_IN_LINE.
ZOOMED_IN_LINE = np.linspace(np.log10(5362.5), np.log10(5363.25), 1000)

los_vels = np.nan_to_num(np.array(mt.los_velocities))
norm = mpl.colors.Normalize(vmin=jnp.min(mt.abundances[:, 16]),
                            vmax=jnp.max(mt.abundances[:, 16]))

fig = plt.figure(figsize=(10, 9))
camera = Camera(fig)
spec = fig.add_gridspec(7, 9)
ax = fig.add_subplot(spec[:5, :8], projection='3d')
cbar_ax = fig.add_subplot(spec[1:4, 8])
spectr_ax = fig.add_subplot(spec[5:, 1:])
# Raw string: '\A' is an invalid escape in a normal literal (SyntaxWarning on Python 3.12+).
spectr_ax.set_xlabel(r'Wavelength [$\AA$]')
# model_spectrum output [0], not divided by [1] as in the normalized plots above.
spectr_ax.set_ylabel('Flux')

axes_lim = 1.5*mt.radius
ax.set_xlim3d(-axes_lim, axes_lim)
ax.set_ylim3d(-axes_lim, axes_lim)
ax.set_zlim3d(-axes_lim, axes_lim)
cmap = 'magma'

for phase_index in range(len(mt.centers)):
    centers = mt.centers[phase_index]

    p = ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2],
                   cmap=cmap, norm=norm,
                   c=mt.abundances[:, 16])

    # Colour bar is redrawn every frame, presumably so each celluloid snapshot contains it.
    cbar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cbar_ax)
    cbar.set_label('Fe abundance')
    atmo = model_spectrum(mt, phase_index, ZOOMED_IN_LINE)
    spectr_ax.plot(np.power(10, ZOOMED_IN_LINE), atmo[0], color='black')
    camera.snap()

In [None]:
# Turn the snapped frames into an animation and write it to spot.gif.
animation = camera.animate()
animation.save(filename='spot.gif')

In [None]:
fig, ax = plt.subplots()