In [2]:
from sim import NBodySimulation, generateDisk3Dv3
import numpy as np
import json
import glob
import os

In [4]:
def gen_params():
    """Draw a random parameter set for one disk galaxy.

    All randomness comes from NumPy's global RNG (``np.random``), so the result
    is reproducible only if the caller seeds it beforehand. Every value is
    converted to a built-in Python type (int/float/str/list) so the dict can be
    passed straight to ``json.dumps`` as well as to ``generateDisk3Dv3(**...)``.

    Returns:
        dict: keys nbStars, radius, Mass, zOffsetMax, gravityCst, distribution,
        offset, initial_vel, clockwise, angle (in that order). ``offset``,
        ``initial_vel`` and ``angle`` are 3-element lists; ``clockwise`` is 0 or 1;
        each ``angle`` component lies in [-2*pi, 2*pi).
    """
    uniform = np.random.uniform

    # The draw order below (scalars first, then each 3-vector component by
    # component) fixes the RNG stream consumption; keep it stable so seeded
    # runs stay reproducible.
    n_stars = int(np.random.randint(100, 500))
    disk_radius = float(uniform(1.0, 2.0))
    total_mass = float(uniform(1.0, 3.0))
    z_spread = float(uniform(0, 0.5))
    center = [float(uniform(-1, 1)) for _ in range(3)]
    bulk_velocity = [float(uniform(-0.1, 0.1)) for _ in range(3)]
    spin_flag = int(np.random.choice([1, 0]))
    euler = [float(uniform(-1, 1) * 2 * np.pi) for _ in range(3)]

    return {
        'nbStars': n_stars,
        'radius': disk_radius,
        'Mass': total_mass,
        'zOffsetMax': z_spread,
        'gravityCst': 1.0,
        'distribution': 'hernquist',
        'offset': center,
        'initial_vel': bulk_velocity,
        'clockwise': spin_flag,
        'angle': euler,
    }

def generate_scene_2gals(t_end=10.0, dt=0.01, softening=0.1, verbose=True):
    """Simulate a random two-galaxy encounter and return it as a JSON string.

    Two parameter sets are drawn with ``gen_params``, turned into particle sets
    with ``generateDisk3Dv3``, merged, and evolved with ``NBodySimulation``.
    Every state saved by ``sim.run`` becomes one entry of ``frames``.

    Args:
        t_end: end time passed to ``sim.run``.
        dt: integration time step given to ``NBodySimulation``.
        softening: softening length given to ``NBodySimulation``.
        verbose: if True (default, the historical behaviour), print both
            galaxies' parameter dicts.

    Returns:
        str: JSON document (indent=4) with keys galaxy1_params, galaxy2_params,
        dt, softening, G, t_end, masses, types, KE, PE and frames; each frame is
        ``{'frame': int, 'pos': ..., 'vel': ..., 'acc': ...}``.
    """
    params1 = gen_params()
    params2 = gen_params()
    if verbose:
        print(params1)
        print(params2)

    particles1 = generateDisk3Dv3(**params1)
    particles2 = generateDisk3Dv3(**params2)

    # Kept equal to the 'gravityCst' that gen_params passes to generateDisk3Dv3.
    G = 1.0

    # NOTE(review): assumes generateDisk3Dv3 returns a list, so '+' concatenates
    # the two particle sets rather than adding arrays element-wise — confirm in sim.py.
    particles = particles1 + particles2
    sim = NBodySimulation(particles, G, softening, dt)

    pos, vel, acc, KE, PE, _, masses, types = sim.run(t_end=t_end, save_states=True)

    # Move axis 2 to the front so the outer index is the saved step.
    # NOTE(review): assumes sim.run stacks states as (n_particles, 3, n_steps) — confirm.
    pos = np.array(pos).transpose(2, 0, 1)
    vel = np.array(vel).transpose(2, 0, 1)
    acc = np.array(acc).transpose(2, 0, 1)

    # json.dumps cannot serialise NumPy arrays, so convert to lists of Python floats.
    KE = KE.flatten().astype(float).tolist()
    PE = PE.flatten().astype(float).tolist()
    masses = masses.flatten().astype(float).tolist()

    frames = [
        {'frame': i, 'pos': p.tolist(), 'vel': v.tolist(), 'acc': a.tolist()}
        for i, (p, v, a) in enumerate(zip(pos, vel, acc))
    ]

    final_json = {
        'galaxy1_params': params1,
        'galaxy2_params': params2,
        'dt': float(dt),
        'softening': float(softening),
        'G': float(G),
        't_end': float(t_end),
        'masses': masses,
        'types': types,  # passed through as returned by sim.run; must be JSON-serialisable
        'KE': KE,
        'PE': PE,
        'frames': frames,
    }

    return json.dumps(final_json, indent=4)


def generate_dataset(n_scenes=5, window_size=3, shuffle=True, dir='./train/', save=True):
    """Build sliding-window samples from freshly simulated two-galaxy scenes.

    Each scene comes from ``generate_scene_2gals``. For every start frame ``j``
    such that frames ``j .. j + window_size - 1`` all exist, one sample is made
    with the scene masses, ``pos``/``vel`` at frame ``j`` and, for
    ``k = 1 .. window_size - 1``, ``pos_next{k}``/``vel_next{k}`` at frame
    ``j + k``. Accelerations are not stored in the samples.

    Args:
        n_scenes: number of independent scenes to simulate.
        window_size: frames per sample (input frame plus ``window_size - 1``
            target frames); must be >= 1.
        shuffle: if True, shuffle the samples in place with ``np.random.shuffle``.
        dir: output directory (name kept for caller compatibility even though it
            shadows the builtin). Output files are named
            ``<basename of dir>_<id>.json`` (e.g. ``./val/val_0.json``), where
            ``id`` is one more than the largest id among existing
            ``*_<int>.json`` files in that directory.
        save: if True, create ``dir`` if needed and write the samples there.

    Returns:
        list[dict]: the samples (identical to what is written to disk).

    Raises:
        ValueError: if ``window_size < 1``.
    """
    if window_size < 1:
        raise ValueError(f'window_size must be >= 1, got {window_size}')

    dataset = []
    for _ in range(n_scenes):
        # generate_scene_2gals returns a JSON string; decode it to reach the frames.
        scene = json.loads(generate_scene_2gals())
        dataset.extend(_windowed_samples(scene['frames'], scene['masses'], window_size))

    if shuffle:
        np.random.shuffle(dataset)
    if save:
        out_path = _next_dataset_path(dir)
        with open(out_path, 'w') as f:
            json.dump(dataset, f, indent=4)

    return dataset


def _windowed_samples(frames, masses, window_size):
    """Return one sample dict per complete window of ``window_size`` consecutive frames."""
    samples = []
    # The last valid start index is len(frames) - window_size, hence the +1.
    for j in range(len(frames) - window_size + 1):
        sample = {
            'masses': masses,
            'pos': frames[j]['pos'],
            'vel': frames[j]['vel'],
        }
        for k in range(1, window_size):
            sample['pos_next{}'.format(k)] = frames[j + k]['pos']
            sample['vel_next{}'.format(k)] = frames[j + k]['vel']
        samples.append(sample)
    return samples


def _next_dataset_path(out_dir):
    """Create ``out_dir`` if needed and return its next free ``<dirname>_<id>.json`` path."""
    os.makedirs(out_dir, exist_ok=True)
    # abspath drops any trailing separator, so basename is the directory's own name.
    prefix = os.path.basename(os.path.abspath(out_dir))
    last_id = -1
    for existing in glob.glob(os.path.join(out_dir, '*.json')):
        stem = os.path.splitext(os.path.basename(existing))[0]
        _, sep, suffix = stem.rpartition('_')
        # Skip JSON files that don't follow the <name>_<int>.json pattern instead of crashing.
        if sep and suffix.isdigit():
            last_id = max(last_id, int(suffix))
    return os.path.join(out_dir, f'{prefix}_{last_id + 1}.json')





In [5]:
# Validation split: simulate 5 scenes, keep windows of 3 frames in order (no
# shuffle), and write them to the next free val_<id>.json file under ./val/.
dataset = generate_dataset(
    n_scenes=5,
    window_size=3,
    shuffle=False,
    dir='./val/',
    save=True,
)

{'nbStars': 261, 'radius': 1.70483688036224, 'Mass': 2.4440477531394436, 'zOffsetMax': 0.28610406287310297, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.1858981787706664, -0.24959811617877792, 0.011070678860936933], 'initial_vel': [-0.09403734386277066, -0.0772198612275312, 0.06929596506416486], 'clockwise': 0, 'angle': [-2.3779386181428652, -1.5478815961027756, 4.3183322883283495]}
{'nbStars': 258, 'radius': 1.4669600765537436, 'Mass': 1.978008678633857, 'zOffsetMax': 0.2837449092751201, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.5173191106063031, -0.574694082169575, -0.2273263238240577], 'initial_vel': [-0.011910666199616185, 0.01722503376958255, 0.011167831479451246], 'clockwise': 0, 'angle': [1.3129815663268316, -5.57037041367148, 3.8837849322014346]}


100%|██████████| 1000/1000 [00:26<00:00, 37.38it/s]


{'nbStars': 425, 'radius': 1.2825177845328057, 'Mass': 2.5839875444322953, 'zOffsetMax': 0.258941972101402, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.5402505911546582, -0.41691917272405044, 0.8643631881378888], 'initial_vel': [-0.0332591398885443, -0.0484747853937326, 0.021412462998334214], 'clockwise': 0, 'angle': [-2.267617488789109, -2.0118349570921645, -2.3253149006748846]}
{'nbStars': 413, 'radius': 1.236021401270027, 'Mass': 1.6841150440566959, 'zOffsetMax': 0.2424963795589563, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.9208004670075041, 0.7497195827407674, 0.6114145918269405], 'initial_vel': [0.06534662850584391, 0.005385561619011131, 0.057715821069913215], 'clockwise': 1, 'angle': [5.966980468181682, 5.643651160491204, -0.453879636336316]}


100%|██████████| 1000/1000 [01:07<00:00, 14.85it/s]


{'nbStars': 208, 'radius': 1.7102454512804628, 'Mass': 2.1723231360596116, 'zOffsetMax': 0.2542760871031119, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.04642475516861855, -0.8039039777507349, -0.47154625530596594], 'initial_vel': [-0.03017353550640077, 0.015110609025400665, -0.07970051495340913], 'clockwise': 1, 'angle': [-5.6974894780789525, -3.021598372474945, -4.934907710690914]}
{'nbStars': 483, 'radius': 1.4167423411980336, 'Mass': 2.1064072939911735, 'zOffsetMax': 0.05133285727850506, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.3550740598665274, 0.6157074592995915, 0.04469643908171306], 'initial_vel': [0.05479719723539822, -0.09169709384472875, -0.024988531310201845], 'clockwise': 1, 'angle': [5.446487435253293, 3.1069588648132784, 3.0146750447435413]}


100%|██████████| 1000/1000 [00:29<00:00, 34.20it/s]


{'nbStars': 242, 'radius': 1.5374199363928063, 'Mass': 1.646361560176956, 'zOffsetMax': 0.13119425755143133, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.08874935920959337, 0.5693984281373903, 0.13300242908506266], 'initial_vel': [0.08765454379085988, -0.01828015591917935, 0.041044633100235595], 'clockwise': 0, 'angle': [6.049622996334708, 0.8249558730999722, -4.811199508987864]}
{'nbStars': 456, 'radius': 1.2108140910879637, 'Mass': 1.927146354454598, 'zOffsetMax': 0.19024847455030475, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.7577375689146286, 0.02577142847526659, -0.10295487549162918], 'initial_vel': [-0.04197444396087775, 0.06075032579814352, -0.08586931510140002], 'clockwise': 1, 'angle': [2.030353462093923, 1.4055730650744276, 6.123273216409077]}


100%|██████████| 1000/1000 [00:29<00:00, 34.36it/s]


{'nbStars': 451, 'radius': 1.7903746579286106, 'Mass': 1.6394494203815877, 'zOffsetMax': 0.3062838695703758, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.42223791678172873, 0.1837300337734411, -0.3830986348697232], 'initial_vel': [-0.06499259513999991, 0.0419098788134355, 0.07354205024876795], 'clockwise': 1, 'angle': [3.3540232537731196, -6.207005462227422, -4.74720086903305]}
{'nbStars': 432, 'radius': 1.2490111350414148, 'Mass': 1.9569082034797671, 'zOffsetMax': 0.11721922057350603, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.9855673583032385, -0.6619576800338758, -0.6644620296050887], 'initial_vel': [0.06750288624943937, 0.07465872787880043, -0.08725992135266217], 'clockwise': 0, 'angle': [0.9769286143399672, -3.6009984969785758, -0.4529366731994696]}


100%|██████████| 1000/1000 [00:46<00:00, 21.64it/s]
