In [29]:
import json
import os

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import HTML

from datagen import *
from kwisatzHaderach import *


In [30]:
def animate3d(scene):
    """Render a scene as an inline HTML5 video of an animated 3D scatter plot.

    Args:
        scene: JSON text, or the already-parsed dict, with the keys
            'masses' (one value per particle), 'types' (one label per
            particle; the label 'black hole' gets special styling) and
            'frames' (a list of dicts whose 'pos' entry holds the particle
            positions for that frame).

    Returns:
        An ``IPython.display.HTML`` object wrapping the encoded animation.
    """
    # Accept a parsed scene as well as its JSON text so callers holding a dict
    # do not have to re-serialise it first.
    if isinstance(scene, (str, bytes, bytearray)):
        scene = json.loads(scene)

    # shape of pos_save is (num_frames, num_particles, 3)
    pos_save = np.array([f['pos'] for f in scene['frames']])
    num_frames, num_particles = pos_save.shape[:2]

    # Per-particle arrays are trimmed to the particle count (axis 1 of pos_save),
    # not to the frame count.
    mass_save = np.asarray(scene['masses'], dtype=float)[:num_particles]
    # Black holes are identified from their declared type rather than by
    # assuming they are the heaviest bodies.
    is_black_hole = (np.asarray(scene['types']) == 'black hole')[:num_particles]

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_zlim(-2, 2)

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    # Black background; the three axis panes are left unfilled so it shows through.
    ax.set_facecolor('black')
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False

    # Colour regular particles by mass with the 'autumn' colormap; black holes
    # stay opaque black.
    cmap = plt.cm.autumn
    colors = np.zeros((num_particles, 4))
    colors[:, 3] = 1.0
    regular = ~is_black_hole
    if regular.any():
        regular_masses = mass_save[regular]
        min_mass = regular_masses.min()
        mass_span = regular_masses.max() - min_mass
        if mass_span > 0:
            normalised = (regular_masses - min_mass) / mass_span
        else:
            # All regular particles share one mass: avoid 0/0 and use the low
            # end of the colormap for all of them.
            normalised = np.zeros_like(regular_masses)
        colors[regular] = cmap(normalised)

    scat = ax.scatter(pos_save[0, :, 0], pos_save[0, :, 1], pos_save[0, :, 2], c=colors)

    # Marker size grows with sqrt(mass); black holes get a fixed small marker.
    sizes = 10000 * np.sqrt(mass_save)
    sizes[is_black_hole] = 10
    scat.set_sizes(sizes)

    def update(i):
        # Update the positions of the particles
        scat._offsets3d = (pos_save[i, :, 0], pos_save[i, :, 1], pos_save[i, :, 2])

    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=10)
    # Close the figure so the notebook shows only the video, not a static plot.
    plt.close(fig)
    return HTML(ani.to_html5_video())



In [31]:
# Generate a scene with generate_scene_2gals() (presumably provided by the
# `datagen` star import) and parse its JSON text into a dict; later cells read
# its 'masses' and 'frames' entries.
# NOTE(review): no random seed is set in this notebook -- confirm whether the
# generator is deterministic if the recorded outputs must be reproducible.
scene = json.loads(generate_scene_2gals())

In [32]:
# Number of mass entries in the generated scene.
len(scene['masses'])

502

In [33]:
def get_pos_vel_mass_types(scene):
    """Split a scene into position, velocity, mass and type data.

    Args:
        scene: JSON text, or the already-parsed dict, with the keys 'masses',
            'types' and 'frames'; every frame is a dict with 'pos' and 'vel'.

    Returns:
        Tuple ``(pos, vel, masses, types)`` where ``pos`` and ``vel`` are
        arrays stacked over frames (frame index first), ``masses`` is an array
        built from ``scene['masses']`` and ``types`` is ``scene['types']``
        returned unchanged.
    """
    # Accept a parsed scene directly so callers holding a dict do not have to
    # round-trip it through json.dumps() first.
    if isinstance(scene, (str, bytes, bytearray)):
        scene = json.loads(scene)
    frames = scene['frames']
    pos_save = np.array([frame['pos'] for frame in frames])
    vel_save = np.array([frame['vel'] for frame in frames])
    mass_save = np.array(scene['masses'])
    return pos_save, vel_save, mass_save, scene['types']

In [34]:
def model_create_sim(model, initial_positions, initial_velocities, masses, types, num_steps, device):
    """Roll ``model`` forward autoregressively and return the trajectory as scene JSON.

    The model is called as ``model(pos, vel, mass)`` and must return the next
    ``(pos, vel)`` pair; each prediction is fed back in as the next input.

    Args:
        model: torch module; it is switched to eval mode and moved to ``device``.
        initial_positions: array-like of starting positions (frame 0).
        initial_velocities: array-like of starting velocities (frame 0).
        masses: array-like of per-particle masses; passed to the model with a
            trailing unit dimension added.
        types: per-particle type labels, copied into the output unchanged.
        num_steps: total number of frames to emit, including frame 0.
        device: torch device (string or object) to run the rollout on.

    Returns:
        JSON text with the keys 'masses', 'types' and 'frames'; each frame is
        ``{'frame': index, 'pos': [...], 'vel': [...]}``.
    """
    def _to_list(values):
        # Arrays and tensors expose .tolist(); plain sequences are copied as-is,
        # so callers may pass lists as well as numpy arrays.
        return values.tolist() if hasattr(values, 'tolist') else list(values)

    model.eval()
    model = model.to(device)
    mass_tensor = torch.tensor(masses, dtype=torch.float32).unsqueeze(1).to(device)
    pos = torch.tensor(initial_positions, dtype=torch.float32).to(device)
    vel = torch.tensor(initial_velocities, dtype=torch.float32).to(device)
    print(mass_tensor.shape, pos.shape, vel.shape)

    frames = [{
        'frame': 0,
        'pos': _to_list(initial_positions),
        'vel': _to_list(initial_velocities)
    }]

    with torch.no_grad():
        for step in range(1, num_steps):
            pos, vel = model(pos, vel, mass_tensor)
            # Tensor.tolist() already yields nested Python floats, and no graph
            # is recorded under no_grad, so no detach or per-element cast is needed.
            frames.append({
                'frame': step,
                'pos': pos.cpu().tolist(),
                'vel': vel.cpu().tolist()
            })

    return json.dumps({
        'masses': _to_list(masses),
        # Converting here keeps json.dumps working when `types` arrives as a
        # numpy array of strings instead of a list.
        'types': _to_list(types),
        'frames': frames
    })
    

    

In [38]:
# Build the network and restore the highest-numbered checkpoint from ./models/.
model = KwisatzHaderach(activation=True, layer_channels=[64, 64, 32, 3])


def _checkpoint_index(filename):
    """Return the integer N of a '<prefix>_<N>.<ext>' filename, or None if it does not parse."""
    try:
        return int(filename.split('_')[1].split('.')[0])
    except (IndexError, ValueError):
        return None


# Files that do not follow the '<prefix>_<N>.<ext>' pattern (e.g. '.DS_Store')
# are skipped instead of crashing the sort key.
model_weights_list = [name for name in os.listdir('./models/') if _checkpoint_index(name) is not None]
if not model_weights_list:
    raise FileNotFoundError("no checkpoint named like '<prefix>_<N>.<ext>' found in ./models/")
model_weights_list.sort(key=_checkpoint_index)
last_model = model_weights_list[-1]
print(last_model)
# map_location='cpu' lets weights saved from a GPU load on a machine without
# one; load_state_dict copies the values into the model's existing parameters.
model.load_state_dict(torch.load(f'./models/{last_model}', map_location='cpu'))


model_13.pt


<All keys matched successfully>

In [50]:
# Seed the model with ground-truth frame 20 of the generated scene and roll it
# forward for 1001 frames.
pos, vel, mass, types = get_pos_vel_mass_types(json.dumps(scene))
initial_positions = pos[20]
initial_velocities = vel[20]
print(initial_positions.shape, initial_velocities.shape, mass.shape)
# Fall back to the CPU when no GPU is available instead of failing on 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
predicted_scene = model_create_sim(model, initial_positions, initial_velocities, mass, types, 1001, device)

(502, 3) (502, 3) (502,)
torch.Size([502, 1]) torch.Size([502, 3]) torch.Size([502, 3])


In [51]:
# Parse the predicted scene once; it holds every frame of the rollout, so
# re-parsing it for each field is needlessly expensive.
predicted = json.loads(predicted_scene)
masses = np.array(predicted['masses'])
types = np.array(predicted['types'])

# The rollout was seeded from ground-truth frame 20, so predicted frame k lines
# up with ground-truth frame 20 + k. `scene` is already a parsed dict, so it is
# read directly instead of being round-tripped through JSON.
gt_pos = np.array([f['pos'] for f in scene['frames']])[20:]
pred_pos = np.array([f['pos'] for f in predicted['frames']])

# Trim both to the frames they have in common, so the arrays stay aligned even
# if the two scenes do not contain the same number of frames.
num_common_frames = min(len(gt_pos), len(pred_pos))
gt_pos = gt_pos[:num_common_frames]
pred_pos = pred_pos[:num_common_frames]

In [52]:
# Shapes of the aligned ground-truth / predicted positions and the per-particle arrays.
tuple(arr.shape for arr in (gt_pos, pred_pos, masses, types))

((981, 502, 3), (981, 502, 3), (502,), (502,))

In [53]:
def euclidean_distance(a, b, epsilon=1e-9):
    """Euclidean distance between tensors ``a`` and ``b`` over the last dimension.

    ``epsilon`` is added to the summed squares before the square root is taken.
    """
    delta = a - b
    squared_norm = (delta ** 2).sum(dim=-1)
    return torch.sqrt(squared_norm + epsilon)

def euclidean_distance_linalg(a, b):
    """Euclidean norm of ``a - b`` taken along the last axis (numpy version)."""
    offset = a - b
    return np.linalg.norm(offset, axis=-1)

In [54]:
# Position error per frame: distance for every particle, averaged over the particles.
euclidean_distance_linalg(gt_pos, pred_pos).mean(axis=-1)

array([0.00000000e+00, 1.16990409e-04, 4.81608997e-04, 1.10547202e-03,
       1.98800512e-03, 3.12682893e-03, 4.51969630e-03, 6.16445877e-03,
       8.05894242e-03, 1.02009389e-02, 1.25879527e-02, 1.52174734e-02,
       1.80870835e-02, 2.11943881e-02, 2.45370028e-02, 2.81124604e-02,
       3.19182096e-02, 3.59515606e-02, 4.02097633e-02, 4.46899587e-02,
       4.93892224e-02, 5.43042760e-02, 5.94317471e-02, 6.47684697e-02,
       7.03115523e-02, 7.60584832e-02, 8.20067979e-02, 8.81541514e-02,
       9.44980655e-02, 1.01035794e-01, 1.07764584e-01, 1.14682171e-01,
       1.21786428e-01, 1.29075636e-01, 1.36548563e-01, 1.44204611e-01,
       1.52044307e-01, 1.60068947e-01, 1.68279919e-01, 1.76678494e-01,
       1.85265549e-01, 1.94041773e-01, 2.03007617e-01, 2.12163135e-01,
       2.21508008e-01, 2.31041491e-01, 2.40762382e-01, 2.50669120e-01,
       2.60759812e-01, 2.71032374e-01, 2.81484593e-01, 2.92114148e-01,
       3.02918768e-01, 3.13896264e-01, 3.25044442e-01, 3.36361131e-01,
      

In [56]:
def animate_ground_truth_vs_prediction(masses, types, gt_pos, pred_pos):
    """Animate ground-truth and predicted trajectories side by side in 3D.

    Args:
        masses: per-particle masses; they drive marker colour and size.
        types: per-particle type labels; entries equal to 'black hole' are
            drawn opaque black with a fixed marker size.
        gt_pos: ground-truth positions indexed as [frame, particle, xyz].
        pred_pos: predicted positions with the same layout as ``gt_pos``.

    Returns:
        An ``IPython.display.HTML`` object wrapping the encoded animation.
    """
    gt_pos = np.asarray(gt_pos)
    pred_pos = np.asarray(pred_pos)
    num_particles = gt_pos.shape[1]
    # Animate only the frames both sequences have, so update() cannot index
    # past the end of the shorter one.
    num_frames = min(gt_pos.shape[0], pred_pos.shape[0])

    # Per-particle arrays are trimmed to the particle count (axis 1 of gt_pos),
    # not to the frame count.
    masses = np.asarray(masses, dtype=float)[:num_particles]
    is_black_hole = (np.asarray(types) == 'black hole')[:num_particles]

    fig = plt.figure(figsize=(16, 8))

    # Two identically configured 3D axes: ground truth on the left, prediction
    # on the right.
    axes = []
    for slot, title in ((121, 'Ground Truth'), (122, 'Prediction')):
        ax = fig.add_subplot(slot, projection='3d')
        ax.set_xlim(-2, 2)
        ax.set_ylim(-2, 2)
        ax.set_zlim(-2, 2)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        ax.set_title(title)
        axes.append(ax)
    ax1, ax2 = axes

    # Colour regular particles by mass with the 'autumn' colormap; black holes
    # stay opaque black.
    cmap = plt.cm.autumn
    colors = np.zeros((num_particles, 4))
    colors[:, 3] = 1.0
    regular = ~is_black_hole
    if regular.any():
        regular_masses = masses[regular]
        min_mass = regular_masses.min()
        mass_span = regular_masses.max() - min_mass
        if mass_span > 0:
            normalised = (regular_masses - min_mass) / mass_span
        else:
            # All regular particles share one mass: avoid 0/0 and use the low
            # end of the colormap for all of them.
            normalised = np.zeros_like(regular_masses)
        colors[regular] = cmap(normalised)

    # Marker size grows with sqrt(mass); black holes get a fixed small marker.
    sizes = 100 * np.sqrt(masses)
    sizes[is_black_hole] = 10

    scat1 = ax1.scatter(gt_pos[0, :, 0], gt_pos[0, :, 1], gt_pos[0, :, 2], c=colors)
    scat1.set_sizes(sizes)
    scat2 = ax2.scatter(pred_pos[0, :, 0], pred_pos[0, :, 1], pred_pos[0, :, 2], c=colors)
    scat2.set_sizes(sizes)

    def update(i):
        # Update the positions of the particles
        scat1._offsets3d = (gt_pos[i, :, 0], gt_pos[i, :, 1], gt_pos[i, :, 2])
        scat2._offsets3d = (pred_pos[i, :, 0], pred_pos[i, :, 1], pred_pos[i, :, 2])

    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=10)
    # Close the figure so the notebook shows only the video, not a static plot.
    plt.close(fig)
    return HTML(ani.to_html5_video())

In [57]:
# Render the side-by-side ground-truth vs. prediction animation; as the cell's
# last expression, the returned HTML object is displayed inline.
animate_ground_truth_vs_prediction(masses, types, gt_pos, pred_pos)