In [None]:
%matplotlib Qt5
import matplotlib.pyplot as plt
from matplotlib import font_manager

import numpy as np

from einops import rearrange, repeat

from twoBodies import TwoBody
from data import kinetic_energy, potential_energy, total_energy, state_generator
# ---- Hyper-parameters and shared figure configuration -----------------------
# NOTE: both simulation cells below reassign `maxIter` (50 and 500) before use.
maxIter: int = 50

# Legacy global seed; presumably `state_generator` draws from np.random, which
# would make the initial state reproducible — TODO confirm.
np.random.seed(114514)
# Print arrays in full (no "..." summarisation) so the state dumps are complete.
np.set_printoptions(threshold=np.inf)

# Font sizes (points) for axis labels, tick labels, titles and legends.
fontsize = font_manager.FontProperties(size=9)
tick_fontsize = font_manager.FontProperties(size=8)
title_fontsize = font_manager.FontProperties(size=8)
legned_fontsize = font_manager.FontProperties(size=5.5)  # (sic) name is read by later cells

# Panel grid: `row` x `col` subplots on a single figure.
row, col = 2, 3

# NOTE(review): this requests a 72 x 48 inch canvas at 400 dpi
# (28800 x 19200 px); confirm the size is intended.
fig = plt.figure(figsize=(24 * col, 24 * row), dpi=400)

# Running subplot slot; each panel increments it before calling plt.subplot.
pic_idx = 0

# Initial state of the N-body system.
nbodies     = 2
mass        = 1.0
min_radius  = 1.0
max_radius  = 1.0
orbit_noise = 0.05
init_state = state_generator(
    nbodies=nbodies,
    mass=mass,
    min_radius=min_radius,
    max_radius=max_radius,
    orbit_noise=orbit_noise,
)
print(f'init_state:\n{init_state}')

# Panel-letter counter: incremented before each panel, so the first one is 'a'.
subgraph_item = ord('a') - 1
# Panel-letter anchor in axes coordinates (left of and above each panel).
item_pos = (-0.07, 1.15)

# Margins/spacing handed to plt.subplots_adjust (set to None to skip it).
tight_layout_arg = {
    'top': 0.945,
    'bottom': 0.111,
    'left': 0.062,
    'right': 0.998,
    'hspace': 0.47,
    'wspace': 0.215,
}

In [None]:
# ---- Forward run: positive step, first row of panels -------------------------
# Integrates the system from `tStart` with a positive `interval` and draws the
# trajectory, energy, and error-vs-ground-truth panels of the first grid row.
tStart     = 0.
tStep      = 100
interval   = 0.05
maxIter    = 50

# Relative padding around the position (q) and energy (e) axis limits.
qscale, escale = 0.5, 0.5

model = TwoBody(
    tStart      = tStart,
    tStep       = tStep,
    interval    = interval,
    init_state  = init_state,
)

# EdSr trajectory plus the q/v error curves of EdSr and of the VV baseline
# (all plotted against `times` in the last panel).
trajs, derror, verror, traderror, traverror = model.init_loop(maxIter)
times = model.times.copy()
print(f'init_state:\n{trajs[0]}')

potential = potential_energy(trajs)
kinetic = kinetic_energy(trajs)
total = total_energy(trajs)

gtpotential = potential_energy(model.orbit)
gtkinetic = kinetic_energy(model.orbit)
gttotal = total_energy(model.orbit)

# Energy-axis limits cover every plotted curve (EdSr and GT) and always include 0.
energies = (potential, kinetic, total, gtpotential, gtkinetic, gttotal)
e_min = min(min(e.min() for e in energies), 0)
e_max = max(e.max() for e in energies)
e_pad = escale * ((e_max - e_min) or 1.0)  # fallback avoids a zero-height axis

# Reorder axes to [[mass, x, y, vx, vy], np, ntrajs]: attribute, particle, step.
trajs, gt = map(lambda arr: rearrange(arr, "ntrajs np attr -> attr np ntrajs"), (trajs, model.orbit))
x_min, x_max = trajs[1].min(), trajs[1].max()
y_min, y_max = trajs[2].min(), trajs[2].max()
print(f'{trajs.shape = }, {gt.shape = }')

# Only positions are plotted. Indexing (instead of unpacking all five attributes)
# keeps the scalar `mass` setting from the first cell from being overwritten.
x, y = trajs[1], trajs[2]
gtx, gty = gt[1], gt[2]

# This cell always fills the first row (panels a, b, c); setting the counters
# explicitly means a re-run does not push its panels into later grid slots.
pic_idx = 0
subgraph_item = ord('a') - 1

# ---- Panel: trajectories of both bodies, EdSr vs. GT -------------------------
pic_idx += 1
subgraph_item += 1

ax = plt.subplot(row, col, pic_idx)
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')
plt.xlabel(r'$\mathrm{q_x}$', fontproperties = fontsize); plt.ylabel(r'$\mathrm{q_y}$', fontproperties = fontsize, rotation = 0)
plt.title('Trajectory', loc = 'center', fontproperties = title_fontsize)

# Marker area is proportional to t.
plt.scatter(x[0], y[0], s = times*1.5, label = r'EdSr $\mathrm{P_1}$')
plt.scatter(x[1], y[1], s = times*1.5, label = r'EdSr $\mathrm{P_2}$')

plt.scatter(gtx[0], gty[0], s = times*0.5, label = r'GT $\mathrm{P_1}$')
plt.scatter(gtx[1], gty[1], s = times*0.5, label = r'GT $\mathrm{P_2}$')

plt.xlim(x_min - qscale * abs(x_min), x_max + qscale * abs(x_max))
# The extra 0.5 at the bottom presumably leaves room for the lower-right legend.
plt.ylim(y_min - qscale * abs(y_min) - 0.5, y_max + qscale * abs(y_max))

plt.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

# ---- Panel: potential / total / kinetic energy, EdSr vs. GT ------------------
pic_idx += 1
subgraph_item += 1

ax = plt.subplot(row, col, pic_idx)
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')
plt.xlabel(r'$\mathrm{\Delta t}$', fontproperties = fontsize)
plt.title('Energy', loc = 'center', fontproperties = title_fontsize)

plt.scatter(times, potential, s = times*2, label = r'EdSr $\mathrm{E_{Pot}}$')
plt.scatter(times, total, s = times*2, label = r'EdSr $\mathrm{E_{Tot}}$')
plt.scatter(times, kinetic, s = times*2, label = r'EdSr $\mathrm{E_{Kin}}$')

plt.scatter(times, gtpotential, s = times*0.5, label = r'GT $\mathrm{E_{Pot}}$')
plt.scatter(times, gttotal, s = times*0.5, label = r'GT $\mathrm{E_{Tot}}$')
plt.scatter(times, gtkinetic, s = times*0.5, label = r'GT $\mathrm{E_{Kin}}$')

plt.xticks(times[::20])
# Energy limits come from the energy curves, not from the trajectory's y-range.
plt.ylim(e_min - e_pad, e_max + e_pad)

plt.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

# ---- Panel: error of EdSr and VV w.r.t. ground truth (log scale) ------------
pic_idx += 1
subgraph_item += 1

ax = plt.subplot(row, col, pic_idx)
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')
plt.xlabel(r'$\mathrm{\Delta t}$', fontproperties = fontsize)
plt.title('MAE compared with GT', loc = 'center', fontproperties = title_fontsize)
plt.yscale('log')
plt.plot(times, derror, label = 'EdSr $q$')
plt.plot(times, verror, label = 'EdSr $v$')
plt.plot(times, traderror, label = 'VV $q$')
plt.plot(times, traverror, label = 'VV $v$')

plt.xticks(times[::20])
plt.yticks(np.logspace(-8, 0, 5))
plt.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

# Subplot margins are figure-wide, so one call after the row's panels suffices.
if tight_layout_arg is not None:
    plt.subplots_adjust(**tight_layout_arg)

In [None]:
# ---- Backward run: negative step, second row of panels -----------------------
# Integrates the same initial state with a negative `interval` and draws the
# trajectory, energy, and error-vs-ground-truth panels of the second grid row.
tStart     = 0.
tStep      = 100
interval   = -0.05
maxIter    = 500

# Relative padding around the position (q) and energy (e) axis limits.
qscale, escale = 0.5, 0.5

model = TwoBody(
    tStart      = tStart,
    tStep       = tStep,
    interval    = interval,
    init_state  = init_state,
)

# EdSr trajectory plus the q/v error curves of EdSr and of the VV baseline
# (all plotted against `times` in the last panel).
trajs, derror, verror, traderror, traverror = model.init_loop(maxIter)
times = model.times.copy()
print(f'init_state:\n{trajs[0]}')

# With a negative step the times presumably fall below tStart, so marker
# areas are taken from |t|.
abs_times = np.abs(times)

potential = potential_energy(trajs)
kinetic = kinetic_energy(trajs)
total = total_energy(trajs)

gtpotential = potential_energy(model.orbit)
gtkinetic = kinetic_energy(model.orbit)
gttotal = total_energy(model.orbit)

# Energy-axis limits cover every plotted curve (EdSr and GT) and always include 0.
energies = (potential, kinetic, total, gtpotential, gtkinetic, gttotal)
e_min = min(min(e.min() for e in energies), 0)
e_max = max(e.max() for e in energies)
e_pad = escale * ((e_max - e_min) or 1.0)  # fallback avoids a zero-height axis

# Reorder axes to [[mass, x, y, vx, vy], np, ntrajs]: attribute, particle, step.
trajs, gt = map(lambda arr: rearrange(arr, "ntrajs np attr -> attr np ntrajs"), (trajs, model.orbit))
x_min, x_max = trajs[1].min(), trajs[1].max()
y_min, y_max = trajs[2].min(), trajs[2].max()
print(f'{trajs.shape = }, {gt.shape = }')

# Only positions are plotted. Indexing (instead of unpacking all five attributes)
# keeps the scalar `mass` setting from the first cell from being overwritten.
x, y = trajs[1], trajs[2]
gtx, gty = gt[1], gt[2]

# This cell always fills the second row (slots col+1 .. 2*col, letters after
# the first row's); setting the counters explicitly keeps re-runs in place.
pic_idx = col
subgraph_item = ord('a') - 1 + col

# ---- Panel: trajectories of both bodies, EdSr vs. GT -------------------------
pic_idx += 1
subgraph_item += 1

ax = plt.subplot(row, col, pic_idx)
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')
plt.xlabel(r'$\mathrm{q_x}$', fontproperties = fontsize); plt.ylabel(r'$\mathrm{q_y}$', fontproperties = fontsize, rotation = 0)
plt.title('Trajectory', loc = 'center', fontproperties = title_fontsize)

# Marker area is proportional to |t|.
plt.scatter(x[0], y[0], s = abs_times*1.5, label = r'EdSr $\mathrm{P_1}$')
plt.scatter(x[1], y[1], s = abs_times*1.5, label = r'EdSr $\mathrm{P_2}$')

plt.scatter(gtx[0], gty[0], s = abs_times*0.5, label = r'GT $\mathrm{P_1}$')
plt.scatter(gtx[1], gty[1], s = abs_times*0.5, label = r'GT $\mathrm{P_2}$')

plt.xlim(x_min - qscale * abs(x_min), x_max + qscale * abs(x_max))
# The extra 1.0 at the bottom presumably leaves room for the lower-right legend.
plt.ylim(y_min - qscale * abs(y_min) - 1.0, y_max + qscale * abs(y_max))

plt.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

# ---- Panel: potential / total / kinetic energy, EdSr vs. GT ------------------
pic_idx += 1
subgraph_item += 1

ax = plt.subplot(row, col, pic_idx)
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')
plt.xlabel(r'$\mathrm{\Delta t}$', fontproperties = fontsize)
plt.title('Energy', loc = 'center', fontproperties = title_fontsize)

# Flip the time axis so the backward run reads left-to-right from tStart.
ax.xaxis.set_inverted(True)
plt.scatter(times, potential, s = abs_times*2, label = r'EdSr $\mathrm{E_{Pot}}$')
plt.scatter(times, total, s = abs_times*2, label = r'EdSr $\mathrm{E_{Tot}}$')
plt.scatter(times, kinetic, s = abs_times*2, label = r'EdSr $\mathrm{E_{Kin}}$')

plt.scatter(times, gtpotential, s = abs_times*0.5, label = r'GT $\mathrm{E_{Pot}}$')
plt.scatter(times, gttotal, s = abs_times*0.5, label = r'GT $\mathrm{E_{Tot}}$')
plt.scatter(times, gtkinetic, s = abs_times*0.5, label = r'GT $\mathrm{E_{Kin}}$')

plt.xticks(times[::20])
# Energy limits come from the energy curves, not from the trajectory's y-range.
plt.ylim(e_min - e_pad, e_max + e_pad)

plt.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

# ---- Panel: error of EdSr and VV w.r.t. ground truth (log scale) ------------
pic_idx += 1
subgraph_item += 1

ax = plt.subplot(row, col, pic_idx)
plt.text(*item_pos, chr(subgraph_item), transform = ax.transAxes, fontsize = fontsize.get_size(), fontweight = 'bold', va = 'top', ha = 'left')
plt.xlabel(r'$\mathrm{\Delta t}$', fontproperties = fontsize)
plt.title('MAE compared with GT', loc = 'center', fontproperties = title_fontsize)
plt.yscale('log')
ax.xaxis.set_inverted(True)
plt.plot(times, derror, label = 'EdSr $q$')
plt.plot(times, verror, label = 'EdSr $v$')
plt.plot(times, traderror, label = 'VV $q$')
plt.plot(times, traverror, label = 'VV $v$')

plt.xticks(times[::20])
plt.yticks(np.logspace(-8, 0, 5))
plt.tick_params(axis = 'both', labelsize = tick_fontsize.get_size())
plt.legend(fontsize = legned_fontsize.get_size(), loc = 'lower right', ncol = 2)

# Subplot margins are figure-wide, so one call after the row's panels suffices.
if tight_layout_arg is not None:
    plt.subplots_adjust(**tight_layout_arg)

In [None]:
# Display the assembled panel grid in the GUI window chosen by the
# `%matplotlib` backend selection at the top of the notebook.
plt.Figure.show(fig)