In [1]:
from datagen import *
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation
import json
from IPython.display import HTML
import torch
from kwisatzHaderach import *


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def animate3d(scene):
    scene = json.loads(scene)
    mass_save = np.array(scene['masses'])
    type_save = scene['types']
    bh_num = np.sum(np.array([1 if t == 'black hole' else 0 for t in type_save]))
    pos_save = np.array([f['pos'] for f in scene['frames']])

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlim(-2, 2)
    ax.set_ylim(-2, 2)
    ax.set_zlim(-2, 2)
    
    
    
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')


    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    # set background color to black
    ax.set_facecolor('black')

    # set panel color to black
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False






    # get index of 2 biggest masses
    bh_index = np.argsort(mass_save)[::-1][:bh_num]

    masses_without_bh = np.delete(mass_save, bh_index)

    max_mass = np.max(masses_without_bh)
    min_mass = np.min(masses_without_bh)

    
    # Create the colormap (inferno)
    cmap = plt.cm.autumn
    colors = np.zeros((pos_save.shape[1], 4))


    for i in range(pos_save.shape[1]):
        if i in bh_index:
            colors[i] = [0, 0, 0, 1]
        else:
            colors[i] = cmap((mass_save[i] - min_mass) / (max_mass - min_mass))
    
    # shape of pos_save is (num_frames, num_particles, 3)
    scat = ax.scatter(pos_save[0, :, 0], pos_save[0, :, 1], pos_save[0, :, 2], c=colors)

    if mass_save is not None:
        masses = mass_save[:pos_save.shape[0]]
        masses = 10000 * np.sqrt(masses)
        masses[bh_index] = 10
        
        scat.set_sizes(masses)

    def update(i):
        # Update the positions of the particles
        scat._offsets3d = (pos_save[i, :, 0], pos_save[i, :, 1], pos_save[i, :, 2])




    ani = animation.FuncAnimation(fig, update, frames=pos_save.shape[0], interval=10)
    plt.close(fig)
    return HTML(ani.to_html5_video())



In [3]:
scene = json.loads(generate_scene_2gals())

In [4]:
len(scene['masses'])

502

In [5]:
def get_pos_vel_mass_types_bh_index(scene):
    scene = json.loads(scene)
    mass_save = np.array(scene['masses'])
    type_save = np.array(scene['types'])
    bh_index = np.where(type_save == 'black hole')[0]
    pos_save = np.array([f['pos'] for f in scene['frames']])
    vel_save = np.array([f['vel'] for f in scene['frames']])
    bh_pos = pos_save[:, bh_index]
    bh_vel = vel_save[:, bh_index]
    bh_mass = mass_save[bh_index]   
    return pos_save, vel_save, mass_save, type_save, bh_pos, bh_vel, bh_mass, bh_index

In [6]:
def get_new_pos_vel(acc, pos, vel, dt=0.01):
    new_vel = vel + acc * dt
    new_pos = pos + new_vel * dt
    return new_pos, new_vel

In [7]:
def model_create_sim(model, initial_positions, initial_velocities, masses, types, num_steps, device):
    model.eval()
    final_json = {
        'masses': masses.tolist(),
        'types': types,
        'frames': []
    }
    frames = []
    model = model.to(device)
    mass_tensor = torch.tensor(masses, dtype=torch.float32).unsqueeze(1).to(device)
    pos = torch.tensor(initial_positions, dtype=torch.float32).to(device)
    vel = torch.tensor(initial_velocities, dtype=torch.float32).to(device)

    
    print(mass_tensor.shape, pos.shape, vel.shape)
    frames.append({
        'frame': 0,
        'pos': initial_positions.tolist(),
        'vel': initial_velocities.tolist()
    })

    with torch.no_grad():
        for i in range(1, num_steps):
            acc = model(pos, vel, mass_tensor)
            pos_pred, vel_pred = get_new_pos_vel(acc, pos, vel)
            frames.append({
                'frame': int(i),
                'pos': [[float(num) for num in pos_i] for pos_i in pos_pred.cpu().detach().numpy().tolist()],
                'vel': [[float(num) for num in vel_i] for vel_i in vel_pred.cpu().detach().numpy().tolist()]
            })
            pos = pos_pred
            vel = vel_pred
    final_json['frames'] = frames
    
    return json.dumps(final_json)
    
def model_create_sim_bh(model, initial_positions, initial_velocities, masses, types, num_steps, device, bh_index):
    model.eval()
    final_json = {
        'masses': masses.tolist(),
        'types': types,
        'frames': []
    }
    frames = []
    bh_index = torch.tensor(bh_index, dtype=torch.long).to(device)
    model = model.to(device)
    mass_tensor = torch.tensor(masses, dtype=torch.float32).unsqueeze(1).to(device)
    pos = torch.tensor(initial_positions, dtype=torch.float32).to(device)
    vel = torch.tensor(initial_velocities, dtype=torch.float32).to(device)
    pos_bh = pos[bh_index].to(device)
    vel_bh = vel[bh_index].to(device)
    mass_bh = mass_tensor[bh_index].to(device)
    
    frames.append({
        'frame': 0,
        'pos': initial_positions.tolist(),
        'vel': initial_velocities.tolist()
    })

    with torch.no_grad():
        for i in range(1, num_steps):
            acc = model(pos, vel, mass_tensor, pos_bh, vel_bh, mass_bh)
            pos_pred, vel_pred = get_new_pos_vel(acc, pos, vel)
            frames.append({
                'frame': int(i),
                'pos': [[float(num) for num in pos_i] for pos_i in pos_pred.cpu().detach().numpy().tolist()],
                'vel': [[float(num) for num in vel_i] for vel_i in vel_pred.cpu().detach().numpy().tolist()]
            })
            pos = pos_pred
            vel = vel_pred
            pos_bh = pos[bh_index]
            vel_bh = vel[bh_index]
    final_json['frames'] = frames
    
    return json.dumps(final_json)

    

In [9]:
model = KwisatzHaderach(activation=True, layer_channels=[128, 128, 64, 64, 3])
model_weights_list = os.listdir('./models/')
model_weights_list.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))
last_model = model_weights_list[-1]
print(last_model)
model.load_state_dict(torch.load(f'./models/{last_model}', map_location=torch.device('cpu')))


model_18.pt


<All keys matched successfully>

In [10]:
pos, vel, mass, types, _, _, _ = get_pos_vel_mass_types_bh_index(json.dumps(scene))
intitial_positions = pos[20]
initial_velocities = vel[20]
print(intitial_positions.shape, initial_velocities.shape, mass.shape)
predicted_scene = model_create_sim(model, intitial_positions, initial_velocities, mass, types, 1001, 'cpu')

(502, 3) (502, 3) (502,)
torch.Size([502, 1]) torch.Size([502, 3]) torch.Size([502, 3])


In [11]:
masses = np.array(json.loads(predicted_scene)['masses'])
gt_pos = np.array([f['pos'] for f in json.loads(json.dumps(scene))['frames']])[20:]
pred_pos = np.array([f['pos'] for f in json.loads(predicted_scene)['frames']])[:-20]
types = np.array(json.loads(predicted_scene)['types'])

In [12]:
gt_pos.shape, pred_pos.shape, masses.shape, types.shape

((981, 502, 3), (981, 502, 3), (502,), (502,))

In [13]:
def euclidean_distance(a, b, epsilon=1e-9):
    return torch.sqrt(torch.sum((a - b)**2, dim=-1) + epsilon)

def euclidean_distance_linalg(a, b):
    return np.linalg.norm(a - b, axis=-1)

In [14]:
np.mean(euclidean_distance_linalg(gt_pos, pred_pos), axis=-1)

array([0.00000000e+00, 1.24785297e-04, 5.08033122e-04, 1.14820082e-03,
       2.04315636e-03, 3.19076841e-03, 4.58889956e-03, 6.23541494e-03,
       8.12817564e-03, 1.02650491e-02, 1.26439074e-02, 1.52626264e-02,
       1.81190894e-02, 2.12111889e-02, 2.45368226e-02, 2.80938983e-02,
       3.18803367e-02, 3.58940660e-02, 4.01330273e-02, 4.45951696e-02,
       4.92784487e-02, 5.41808406e-02, 5.93003231e-02, 6.46348909e-02,
       7.01825383e-02, 7.59412789e-02, 8.19091308e-02, 8.80841196e-02,
       9.44642794e-02, 1.01047651e-01, 1.07832279e-01, 1.14816216e-01,
       1.21997517e-01, 1.29374239e-01, 1.36944441e-01, 1.44706175e-01,
       1.52657493e-01, 1.60796440e-01, 1.69121044e-01, 1.77629321e-01,
       1.86319254e-01, 1.95188794e-01, 2.04235846e-01, 2.13458234e-01,
       2.22853695e-01, 2.32419828e-01, 2.42154068e-01, 2.52053619e-01,
       2.62115422e-01, 2.72336125e-01, 2.82712092e-01, 2.93239482e-01,
       3.03914339e-01, 3.14732712e-01, 3.25690775e-01, 3.36784947e-01,
      

In [15]:
def animate_ground_truth_vs_prediction(masses, types, gt_pos, pred_pos):
    # create a 2 animation side by side, do in 3d    
    fig = plt.figure(figsize=(16, 8))

    ax1 = fig.add_subplot(121, projection='3d')
    ax1.set_xlim(-2, 2)
    ax1.set_ylim(-2, 2)
    ax1.set_zlim(-2, 2)
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')
    ax1.set_zlabel('z')
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_zticks([])
    ax1.set_title('Ground Truth')


    ax2 = fig.add_subplot(122, projection='3d')
    ax2.set_xlim(-2, 2)
    ax2.set_ylim(-2, 2)
    ax2.set_zlim(-2, 2)
    ax2.set_xlabel('x')
    ax2.set_ylabel('y')
    ax2.set_zlabel('z')
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.set_zticks([])
    ax2.set_title('Prediction')

    # get index of 2 black holes
    types = np.array(types)
    bh_index = np.where(types == 'black hole')[0]

    masses_without_bh = np.delete(masses, bh_index)

    max_mass = np.max(masses_without_bh)
    min_mass = np.min(masses_without_bh)

    # Create the colormap (inferno)
    cmap = plt.cm.autumn
    colors = np.zeros((gt_pos.shape[1], 4))

    for i in range(gt_pos.shape[1]):
        if i in bh_index:
            colors[i] = [0, 0, 0, 1]
        else:
            colors[i] = cmap((masses[i] - min_mass) / (max_mass - min_mass))

    # shape of pos_save is (num_frames, num_particles, 3)
    scat1 = ax1.scatter(gt_pos[0, :, 0], gt_pos[0, :, 1], gt_pos[0, :, 2], c=colors)

    if masses is not None:
        masses = masses[:gt_pos.shape[0]]
        masses = 100 * np.sqrt(masses)
        masses[bh_index] = 10
        scat1.set_sizes(masses)

    scat2 = ax2.scatter(pred_pos[0, :, 0], pred_pos[0, :, 1], pred_pos[0, :, 2], c=colors)

    if masses is not None:
        scat2.set_sizes(masses)

    error_per_frame = np.linalg.norm(gt_pos - pred_pos, axis=-1)

    

    def update(i):
        # Update the positions of the particles
        scat1._offsets3d = (gt_pos[i, :, 0], gt_pos[i, :, 1], gt_pos[i, :, 2])
        scat2._offsets3d = (pred_pos[i, :, 0], pred_pos[i, :, 1], pred_pos[i, :, 2])


    ani = animation.FuncAnimation(fig, update, frames=gt_pos.shape[0], interval=10)
    plt.close(fig)
    return HTML(ani.to_html5_video())

In [16]:
animate_ground_truth_vs_prediction(masses, types, gt_pos, pred_pos)