In [None]:
import cupy as cp
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

from core import RNN, MSELoss

In [None]:
# G is in arbitrary simulation units (not SI: no meters or kilograms).
def get_accelerations(pos, masses, G=1.0, eps=1e-5):
    """Gravitational acceleration of every body due to all the other bodies.

    Computes, for each body i,
        a_i = sum_{j != i} G * m_j * (p_j - p_i) / r_ij^3,   r_ij = sqrt(|p_j - p_i|^2 + eps^2)
    which equals G * m_j / r_ij^2 * r_hat (the m_i of F = G m_i m_j / r^2 cancels in a = F / m_i).
    The pairwise sum is done with broadcasting, so a call launches a handful of
    array kernels instead of several per body pair.

    Args:
        pos: (n_bodies, n_dims) array of positions (n_dims is 2 in this notebook).
        masses: length-n_bodies sequence/array of masses.
        G: gravitational constant in simulation units.
        eps: softening length added in quadrature to every pairwise distance so
            close encounters do not produce unbounded accelerations.

    Returns:
        (n_bodies, n_dims) array of accelerations.
    """
    n_bodies = pos.shape[0]
    masses = cp.asarray(masses)
    # disp[i, j] = pos[j] - pos[i]: vector pointing from body i to body j.
    disp = pos[None, :, :] - pos[:, None, :]
    dist_sq = cp.sum(disp ** 2, axis=-1) + eps ** 2
    # A body exerts no force on itself: its disp is already zero, and bumping the
    # diagonal distance keeps eps=0 from producing 0/0 = NaN there.
    dist_sq = dist_sq + cp.eye(n_bodies, dtype=dist_sq.dtype)
    # weights[i, j] = G * m_j / r_ij^3, so weights * disp = G * m_j / r_ij^2 * r_hat.
    weights = G * masses[None, :] / (dist_sq * cp.sqrt(dist_sq))
    return cp.sum(weights[:, :, None] * disp, axis=1)

# The displacement vector is simply pos_j - pos_i: a vector pointing from body i to body j.
# Its magnitude is r = sqrt((x_j - x_i)^2 + (y_j - y_i)^2).
# The force on i is F = G * m_i * m_j / r^2 * r_hat, where r_hat = (pos_j - pos_i) / r is the unit direction vector.
# The acceleration is a_i = F / m_i = G * m_j / r^2 * r_hat (m_i cancels).

In [None]:
def prepare_gravity_data(data, window_size): # data is (steps, features)
    """Slice a trajectory into (input window, next-step delta) training pairs.

    Args:
        data: (steps, features) trajectory, one row per time step.
        window_size: number of consecutive time steps per sample.

    Returns:
        X: (num_samples, window_size, features) float64; X[i] = data[i : i + window_size].
        Y: (num_samples, window_size, features) float64;
           Y[i, t] = data[i + t + 1] - data[i + t], the change that follows each input row.
        num_samples = steps - window_size: every window whose final target
        (data[i + window_size]) still lies inside `data`. It is 0 (empty arrays)
        when the trajectory is not longer than one window.
    """
    num_steps = data.shape[0]
    num_samples = max(num_steps - window_size, 0)

    # deltas[t] = data[t + 1] - data[t]
    deltas = data[1:] - data[:-1]
    # Row i of `window` holds the time indices i, i + 1, ..., i + window_size - 1.
    window = cp.arange(num_samples)[:, None] + cp.arange(window_size)[None, :]

    X = data[window].astype(cp.float64, copy=False)
    Y = deltas[window].astype(cp.float64, copy=False)
    return X, Y

In [None]:
# Initial conditions in simulation units; row order is Sun, inner planet, outer moon.
masses = cp.array([15.0, 2.0, 0.5])
pos = cp.array([[0.0, 0.0],     # Sun
                [3.0, 0.0],     # Inner Planet
                [0.0, 8.0]])    # Outer Moon
vel = cp.array([[0.0, -0.2],    # Sun
                [0.0, 2.2],     # Inner Planet
                [-1.4, 0.0]])   # Outer Moon

dt, steps = 0.01, 30000

# Semi-implicit Euler: velocity is updated first and the *new* velocity then
# advances the position. One flattened snapshot (x0, y0, x1, y1, x2, y2) is
# recorded after every step; the initial state itself is not recorded.
history = []
for i in range(steps):
    acc = get_accelerations(pos, masses)
    vel += dt * acc
    pos += dt * vel
    history.append(pos.flatten())

data = cp.stack(history) # (steps, 6)

In [None]:
input_dim = len(masses) * 2  # two coordinates (x, y) per body
hidden_dim = 128
output_dim = len(masses) * 2  # one (dx, dy) delta per body

model = RNN(input_dim, hidden_dim, output_dim)
loss_fn = MSELoss()

# Positions are divided by their largest absolute value so they lie in [-1, 1].
scale_factor = cp.max(cp.abs(data))
data_normalized = data / scale_factor
print(f"Scale factor: {scale_factor}")
window_size = 20

# Train on the *normalized* trajectory. Rollouts are seeded with
# `data_normalized` and add `prediction * delta_scale` to those positions, so
# both the model inputs and `delta_scale` must be in normalized units. Building
# the windows from raw `data` fed the model a different input range at
# inference time and made each generated step `scale_factor` times too large.
X_train_abs, Y_train_deltas = prepare_gravity_data(data_normalized, window_size)
delta_scale = cp.max(cp.abs(Y_train_deltas))
Y_train_deltas /= delta_scale  # targets now lie in [-1, 1]

epochs = 700
batch_size = 64
learning_rate = 5e-5

# Drop the ragged tail so every batch is full; keep inputs and targets aligned.
num_batches = X_train_abs.shape[0] // batch_size
X_train_abs = X_train_abs[:num_batches * batch_size]
Y_train_deltas = Y_train_deltas[:num_batches * batch_size]

loss_history = []

for epoch in range(epochs):
    indices = cp.random.permutation(X_train_abs.shape[0])

    for step in range(num_batches):
        # Batch `step` is the step-th disjoint chunk of the permutation, so an
        # epoch visits every sample exactly once (slicing from `step` instead of
        # `step * batch_size` gave overlapping batches over a tiny subset).
        batch_start = step * batch_size
        batch_idx = indices[batch_start : batch_start + batch_size]
        x_batch = X_train_abs[batch_idx]  # (batch_size, window_size, 6)
        y_batch = Y_train_deltas[batch_idx]  # (batch_size, window_size, 6)

        y_pred, _ = model.forward(x_batch, h_prev=None)  # same shape as x_batch
        loss = loss_fn.forward(y_pred, y_batch)

        loss_history.append(loss)

        dlogits = loss_fn.backward()
        model.backward(dlogits)

        model.step(learning_rate)

    print(f"Epoch: {epoch} | Loss: {loss}")

In [None]:
# Per-batch training loss on a log scale.
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.plot(cp.stack(loss_history).get())
ax.set_yscale('log')
ax.set_xlabel("Training step (batch)")
ax.set_ylabel("MSE loss on normalized deltas")
ax.set_title("RNN training loss")
plt.show()

In [None]:
def sample(model: "RNN", boot, window_size, num_iterations, scale=None):
    """Autoregressively roll the model forward from a seed window.

    Each iteration runs the model over the current window, reads the prediction
    at the last time step as a normalized position delta, rescales it by
    `scale`, adds it to the last position, and slides the window forward by one.

    Args:
        model: object with ``forward(x, h_prev=None) -> (y, state)``, where y is
            expected to have the same shape as x.
        boot: seed positions; reshaped to (1, window_size, features), so it must
            contain exactly window_size * features elements.
        window_size: number of time steps fed to the model each iteration.
        num_iterations: number of new steps to generate.
        scale: factor that undoes the delta normalization. Defaults to the
            module-level ``delta_scale`` computed by the training cell.

    Returns:
        (num_iterations, features) array of generated positions, in the same
        units as `boot`; an empty (0, features) array when num_iterations <= 0.
    """
    if scale is None:
        scale = delta_scale  # module-level global set by the training cell
    current_state = boot.reshape(1, window_size, -1)
    num_features = current_state.shape[-1]
    if num_iterations <= 0:
        # cp.stack([]) would raise; return an empty trajectory instead.
        return cp.empty((0, num_features), dtype=current_state.dtype)

    trajectory = []
    for _step in range(num_iterations):
        pred_delta_norm, _ = model.forward(current_state, h_prev=None)
        delta = pred_delta_norm[:, -1, :] * scale  # (1, features)
        next_pos = current_state[:, -1, :] + delta  # (1, features)
        trajectory.append(next_pos[0])
        # Drop the oldest step and append the new one.
        current_state = cp.concatenate([current_state[:, 1:, :], next_pos[:, None, :]], axis=1)

    return cp.stack(trajectory)

In [None]:
# Seed the rollout with the first `window_size` real (normalized) positions and
# let the model generate the rest, so seed + dream has `steps` rows in total.
start_idx = window_size
num_dream_steps = steps - start_idx
s = float(scale_factor)

seed_window = data_normalized[:start_idx] 
rnn_normalized_trajectory = sample(model, seed_window, window_size, num_dream_steps)

# Stitch seed and generated steps together, copy to NumPy for matplotlib,
# and undo the position normalization.
dream_normalized = cp.concatenate([seed_window, rnn_normalized_trajectory], axis=0)
rnn_dream = dream_normalized.get() * s # (steps, 6)
real_data = data_normalized.get() * s

# Number of animation frames.
total_frames = steps - 1

In [None]:
# Half-width of the dream panel's view. An autoregressive rollout can diverge to
# inf/NaN, and np.max would then return NaN/inf, which matplotlib rejects as an
# axis limit. Only finite values are considered; if none remain, fall back to
# the real-data scale `s`.
finite_dream = np.abs(rnn_dream[np.isfinite(rnn_dream)])
max_dream = float(finite_dream.max()) if finite_dream.size else s

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

# Fixed, symmetric, square view per panel so the scene does not rescale
# while the animation plays.
for panel, half_width, panel_title in ((ax1, s, "Real Physics"),
                                       (ax2, max_dream, "RNN Dream")):
    panel.set_xlim(-half_width, half_width)
    panel.set_ylim(-half_width, half_width)
    panel.set_aspect('equal')
    panel.grid(True, alpha=0.3)
    panel.set_title(panel_title)

# Per body (red, green, blue): a marker for the current position and a faint
# line for its recent trail, in each panel.
body_colors = ('r', 'g', 'b')
dots_real = [ax1.plot([], [], color + 'o')[0] for color in body_colors]
trails_real = [ax1.plot([], [], color + '-', alpha=0.2)[0] for color in body_colors]

dots_rnn = [ax2.plot([], [], color + 'o')[0] for color in body_colors]
trails_rnn = [ax2.plot([], [], color + '-', alpha=0.2)[0] for color in body_colors]

max_rnn_frames = rnn_dream.shape[0]

def update(i):
    """Draw frame ``i``: move each body's marker and redraw its trailing path.

    The trail covers frames [max(0, i - 1000), i). Returns the modified artists,
    which FuncAnimation uses when blitting.
    """
    if i % 100 == 0:
        print(f"Rendering Frame: {i}")
    trail_len = 1000
    trail_start = max(0, i - trail_len)

    changed = []
    for body in range(3):
        x_col, y_col = 2 * body, 2 * body + 1  # this body's columns in the (frames, 6) arrays
        panels = ((real_data, dots_real[body], trails_real[body]),
                  (rnn_dream, dots_rnn[body], trails_rnn[body]))
        for positions, dot, trail in panels:
            dot.set_data([positions[i, x_col]], [positions[i, y_col]])
            trail.set_data(positions[trail_start:i, x_col], positions[trail_start:i, y_col])
            changed += [dot, trail]

    return changed

# Animate frames 0 .. total_frames - 1 and embed the result as an HTML5 video.
ani = FuncAnimation(
    fig,
    update,
    frames=total_frames,
    interval=2,  # milliseconds between frames
    blit=True,   # redraw only the artists returned by `update`
)
plt.close(fig)  # keep the static figure from rendering below the video
HTML(ani.to_html5_video())