In [1]:
from sim import NBodySimulation, generateDisk3Dv3
import numpy as np
import json
import glob
import os

In [3]:
def gen_params():
    """Draw one random set of galaxy parameters.

    The returned dictionary is what this notebook unpacks as keyword
    arguments into ``generateDisk3Dv3``. Every value is converted to a
    built-in ``int``/``float`` so the dictionary can be dumped to JSON as is.

    Values are drawn from the global ``np.random`` state, in the order the
    keys appear in the result:

    * ``nbStars``     -- integer in [100, 500)
    * ``radius``      -- uniform in [1.0, 2.0)
    * ``Mass``        -- uniform in [1.0, 3.0)
    * ``zOffsetMax``  -- uniform in [0, 0.5)
    * ``gravityCst``  -- constant 1.0
    * ``distribution``-- constant ``'hernquist'``
    * ``offset``      -- three uniforms in [-1, 1)
    * ``initial_vel`` -- three uniforms in [-0.1, 0.1)
    * ``clockwise``   -- 1 or 0
    * ``angle``       -- three uniforms in [-1, 1) scaled by 2*pi

    Returns
    -------
    dict
        Freshly sampled parameters.
    """
    def _uniform(low, high):
        # Single scalar draw, unwrapped from the numpy scalar type.
        return float(np.random.uniform(low, high))

    # The dictionary is filled key by key so the random draws happen in a
    # fixed, explicit order (this matters when the global RNG is seeded).
    sampled = {}
    sampled['nbStars'] = int(np.random.randint(100, 500))
    sampled['radius'] = _uniform(1.0, 2.0)
    sampled['Mass'] = _uniform(1.0, 3.0)
    sampled['zOffsetMax'] = _uniform(0, 0.5)
    sampled['gravityCst'] = 1.0
    sampled['distribution'] = 'hernquist'
    sampled['offset'] = [_uniform(-1, 1) for _axis in range(3)]
    sampled['initial_vel'] = [_uniform(-0.1, 0.1) for _axis in range(3)]
    sampled['clockwise'] = int(np.random.choice([1, 0]))
    sampled['angle'] = [
        float(np.random.uniform(-1, 1) * 2 * np.pi) for _axis in range(3)
    ]
    return sampled

def generate_scene_2gals():
    """Simulate one scene made of two random galaxies and return it as JSON.

    Two parameter sets are drawn with ``gen_params`` (and printed), each is
    turned into particles with ``generateDisk3Dv3``, and the concatenated
    particles are integrated by ``NBodySimulation`` with fixed settings
    (``t_end=10.0``, ``dt=0.01``, ``softening=0.1``, ``G=1.0``).

    Returns
    -------
    str
        A JSON document (``indent=4``) with the keys ``galaxy1_params``,
        ``galaxy2_params``, ``dt``, ``softening``, ``G``, ``t_end``,
        ``masses``, ``types``, ``KE``, ``PE`` and ``frames``. Each entry of
        ``frames`` holds ``frame`` (index), ``pos``, ``vel`` and ``acc``.
    """
    # Integration settings shared by every generated scene.
    G = 1.0
    softening = 0.1
    dt = 0.01
    t_end = 10.0

    def _frame_major(history):
        # Move the last axis to the front; the leading axis is then iterated
        # as the frame index below.
        # NOTE(review): presumably sim.run returns (particle, component, step)
        # arrays -- confirm against NBodySimulation.run.
        return np.array(history).transpose(2, 0, 1)

    def _flat_floats(values):
        # Flatten to a plain Python list of floats for JSON.
        return values.flatten().astype(float).tolist()

    galaxy_params = [gen_params(), gen_params()]
    for one_galaxy in galaxy_params:
        print(one_galaxy)

    first_disk, second_disk = [generateDisk3Dv3(**p) for p in galaxy_params]
    bodies = first_disk + second_disk

    sim = NBodySimulation(bodies, G, softening, dt)
    (pos_hist, vel_hist, acc_hist,
     kinetic, potential, _unused, mass_arr, types) = sim.run(t_end=t_end, save_states=True)

    pos_frames = _frame_major(pos_hist)
    vel_frames = _frame_major(vel_hist)
    acc_frames = _frame_major(acc_hist)
    kinetic_list = _flat_floats(kinetic)
    potential_list = _flat_floats(potential)
    mass_list = _flat_floats(mass_arr)

    frames = [
        {
            'frame': int(idx),
            'pos': pos_frames[idx].tolist(),
            'vel': vel_frames[idx].tolist(),
            'acc': acc_frames[idx].tolist(),
        }
        for idx in range(len(pos_frames))
    ]

    # Key order is kept deliberately: it is the order written to the JSON text.
    scene = {}
    scene['galaxy1_params'] = galaxy_params[0]
    scene['galaxy2_params'] = galaxy_params[1]
    scene['dt'] = float(dt)
    scene['softening'] = float(softening)
    scene['G'] = float(G)
    scene['t_end'] = float(t_end)
    scene['masses'] = list(map(float, mass_list))
    scene['types'] = types  # passed through unchanged from sim.run
    scene['KE'] = kinetic_list
    scene['PE'] = potential_list
    scene['frames'] = frames

    return json.dumps(scene, indent=4)


def generate_dataset(n_scenes=5, window_size=3, shuffle=True, dir='./train/', save=True):
    """Build sliding-window samples from freshly simulated two-galaxy scenes.

    Every scene returned by ``generate_scene_2gals`` is cut into overlapping
    windows of ``window_size`` consecutive frames. A sample stores the scene
    masses, the positions/velocities of the window's first frame (``pos``,
    ``vel``) and those of each following frame in the window
    (``pos_next1``/``vel_next1``, ``pos_next2``/``vel_next2``, ...). The
    intended use is to predict the acceleration from the first frame and
    integrate it to recover the later positions and velocities.

    Parameters
    ----------
    n_scenes : int
        Number of scenes to simulate.
    window_size : int
        Number of consecutive frames per sample (first frame + next frames).
    shuffle : bool
        If true, shuffle the samples in place with ``np.random.shuffle``.
    dir : str
        Output directory, created if missing; a trailing separator is
        optional. The file is named ``<directory name>_<id>.json`` where
        ``id`` is one more than the highest ``*_<id>.json`` already present
        (0 if there is none).
    save : bool
        If true, write the dataset to the JSON file described above.

    Returns
    -------
    list of dict
        The generated samples.
    """
    # Ensure the directory exists
    os.makedirs(dir, exist_ok=True)

    # Resolve the output path before the expensive simulations so that path
    # problems surface immediately. The file prefix is the directory's own
    # name, derived with os.path so that absolute paths, nested paths and a
    # missing trailing slash all work (slicing the raw string only worked for
    # paths shaped exactly like './name/').
    prefix = os.path.basename(os.path.normpath(dir))
    if prefix == os.curdir:
        prefix = ''

    last_id = -1
    for path in glob.glob(os.path.join(glob.escape(dir), '*.json')):
        stem = os.path.splitext(os.path.basename(path))[0]
        head, sep, suffix = stem.rpartition('_')
        # Skip JSON files that do not follow the '<name>_<id>.json' pattern
        # instead of crashing on int().
        if sep and suffix.isdecimal():
            last_id = max(last_id, int(suffix))

    name = os.path.join(dir, f'{prefix}_{last_id + 1}.json')

    # A window starting at j uses frames j .. j + window_size - 1, so the last
    # valid start is len(frames) - window_size (inclusive). max(..., 1) keeps
    # the start index in range for degenerate window sizes below 1.
    span = max(window_size, 1)

    dataset = []
    for _scene_idx in range(n_scenes):
        scene = json.loads(generate_scene_2gals())
        frames = scene['frames']
        masses = scene['masses']
        for j in range(len(frames) - span + 1):
            sample = {
                'masses': masses,
                'pos': frames[j]['pos'],
                'vel': frames[j]['vel'],
            }
            for k in range(1, window_size):
                sample['pos_next{}'.format(k)] = frames[j + k]['pos']
                sample['vel_next{}'.format(k)] = frames[j + k]['vel']
            dataset.append(sample)

    if shuffle:
        np.random.shuffle(dataset)
    if save:
        with open(name, 'w') as f:
            json.dump(dataset, f, indent=4)

    return dataset

# Write ten training files to the default './train/' directory; each call
# simulates 5 scenes and cuts them into 3-frame samples. Only the samples of
# the last call remain bound to `dataset`.
for i in range(10):
    dataset = generate_dataset(n_scenes=5, window_size=3)



{'nbStars': 245, 'radius': 1.6942184313501, 'Mass': 2.1784798707192525, 'zOffsetMax': 0.44736016191542327, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.2876241169817144, -0.22791497866547017, -0.36407795834427126], 'initial_vel': [-0.043719493903054324, 0.05080910553526563, 0.07364263165268234], 'clockwise': 1, 'angle': [4.774561292409996, 5.865086473024548, -2.57797767491116]}
{'nbStars': 397, 'radius': 1.7619680306662144, 'Mass': 2.8174941010279726, 'zOffsetMax': 0.2331961322389859, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.8707549607075169, -0.45976693595036733, -0.10238701682738305], 'initial_vel': [0.03134423233762826, 0.018569483036029205, 0.08901015455161224], 'clockwise': 0, 'angle': [-4.31470343734244, 4.9468465851709045, -2.73070700620388]}


100%|██████████| 1000/1000 [00:26<00:00, 37.41it/s]


{'nbStars': 302, 'radius': 1.9258151231844929, 'Mass': 1.6410181549754754, 'zOffsetMax': 0.41292860367103856, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.38097271585076564, 0.6569928575757358, 0.5347573344383243], 'initial_vel': [0.0060028556023672686, 0.0272400346058998, 0.024064880014372098], 'clockwise': 0, 'angle': [1.8387476851508051, -1.821637186628475, 1.8147641443518574]}
{'nbStars': 427, 'radius': 1.476059788780784, 'Mass': 2.666879916965729, 'zOffsetMax': 0.18352386013567618, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.6898698344100254, -0.8434914793770087, -0.45790713735660415], 'initial_vel': [-0.010369445551267398, -0.05617086338250104, 0.004601578380386523], 'clockwise': 1, 'angle': [-1.1271806078690498, 1.1576487646646567, 4.1467219321350655]}


100%|██████████| 1000/1000 [00:33<00:00, 30.07it/s]


{'nbStars': 381, 'radius': 1.7825678400891267, 'Mass': 2.913872901175841, 'zOffsetMax': 0.1550166197826024, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.9167605003968098, 0.5594703945964572, 0.5931946446708942], 'initial_vel': [0.07566218192326943, -0.062467813231344455, -0.07171895020702011], 'clockwise': 0, 'angle': [0.025037134051446346, -1.5581270046079683, -6.157742664623349]}
{'nbStars': 400, 'radius': 1.217348215863553, 'Mass': 2.5476326996814134, 'zOffsetMax': 0.03114097558626222, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.5888804105772383, 0.46991251974831827, -0.6455090956880225], 'initial_vel': [-0.014060964045966129, 0.04461680651883218, -0.007901299473592222], 'clockwise': 1, 'angle': [-0.8516649664397095, 5.755295328945812, -4.073361165483509]}


100%|██████████| 1000/1000 [00:37<00:00, 26.88it/s]


{'nbStars': 381, 'radius': 1.369512451955912, 'Mass': 1.432281286362356, 'zOffsetMax': 0.17244457941974933, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.09035417992441097, 0.15033787292734546, 0.48904486097331823], 'initial_vel': [-0.07037743701702344, 0.023109603112846155, 0.04779568770585835], 'clockwise': 1, 'angle': [4.0506118239591355, -2.8548715822712665, 4.858546966507633]}
{'nbStars': 450, 'radius': 1.788509960607814, 'Mass': 2.9743175346901705, 'zOffsetMax': 0.3437089460835134, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.39583346074016124, 0.376543895400127, 0.9250446345994969], 'initial_vel': [0.0009997245733592564, -0.06707053513848567, 0.042836018220159006], 'clockwise': 0, 'angle': [-5.340830938796844, -5.005180638285344, 1.3191494210444832]}


100%|██████████| 1000/1000 [00:41<00:00, 23.87it/s]


{'nbStars': 340, 'radius': 1.4309594116338986, 'Mass': 2.9368160443314992, 'zOffsetMax': 0.07231439850180649, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.222128832680478, 0.6429240028427217, -0.9197683207086957], 'initial_vel': [0.07457479314754734, 0.08280371657090865, -0.041110852806733926], 'clockwise': 0, 'angle': [-2.3431537402237828, -6.056520088395501, -6.021024260842825]}
{'nbStars': 439, 'radius': 1.340689626235466, 'Mass': 2.239262257752432, 'zOffsetMax': 0.13232514363290016, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.4701791519140559, -0.0421893920731955, -0.5896038280668701], 'initial_vel': [-0.07397491596981542, -0.0636368049599221, 0.02858172220882635], 'clockwise': 0, 'angle': [3.965365366143832, -3.9947525606101872, 0.22137242272933855]}


100%|██████████| 1000/1000 [00:37<00:00, 26.38it/s]


{'nbStars': 490, 'radius': 1.4274140112017606, 'Mass': 2.9575449628593375, 'zOffsetMax': 0.026100875974956306, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.3765584227987415, 0.7277137831785887, -0.6830777099461713], 'initial_vel': [-0.013758277508799405, -0.06157320599175034, -0.07436259668872454], 'clockwise': 0, 'angle': [2.3265631463276857, 5.356761662264787, -4.50270719773833]}
{'nbStars': 234, 'radius': 1.8261588586316342, 'Mass': 2.507290191842686, 'zOffsetMax': 0.3382304367239208, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.9006681091241933, -0.027477011650471672, 0.5437082088423824], 'initial_vel': [-0.04940920291813871, 0.01140444747389395, -0.044701791331678536], 'clockwise': 1, 'angle': [-3.4549254331054686, 2.917462372917019, 3.7263963257354504]}


100%|██████████| 1000/1000 [00:33<00:00, 29.96it/s]


{'nbStars': 407, 'radius': 1.4670940736809246, 'Mass': 1.517335901975853, 'zOffsetMax': 0.364793377521107, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.18297815874442502, -0.31435307448176864, 0.9658184286203169], 'initial_vel': [-0.004206124528289695, -0.08206218125598064, 0.0631912976662348], 'clockwise': 1, 'angle': [2.4206526988129884, -4.087433120690354, -4.125734965905525]}
{'nbStars': 310, 'radius': 1.0170516596603103, 'Mass': 2.9190198198500195, 'zOffsetMax': 0.07575207497743047, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.6529205587888345, 0.8675332883418763, 0.39576149356028845], 'initial_vel': [0.040047078966868155, 0.07659711726158011, -0.019384956579504703], 'clockwise': 1, 'angle': [-3.4750801329722707, -4.500060143635011, -0.5307614675195582]}


100%|██████████| 1000/1000 [00:31<00:00, 32.00it/s]


{'nbStars': 120, 'radius': 1.8386850634187475, 'Mass': 2.555230565313377, 'zOffsetMax': 0.49005182726270974, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.7663217446703288, 0.36769544467458304, -0.35471878206005836], 'initial_vel': [-0.006237791306330398, 0.05481960014230086, -0.003946552817554544], 'clockwise': 0, 'angle': [-0.16093168361370655, -1.893767686260694, -4.2303675568180585]}
{'nbStars': 327, 'radius': 1.2246012723182975, 'Mass': 1.588356533594199, 'zOffsetMax': 0.3534712336700844, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.07825610955286688, -0.8663062830811312, 0.8081412081747124], 'initial_vel': [-0.016180496592217694, 0.025792032141168347, -0.001514068055299439], 'clockwise': 1, 'angle': [-4.0572461164144515, -1.1773034986929436, -5.357892925577315]}


100%|██████████| 1000/1000 [00:14<00:00, 70.59it/s]


{'nbStars': 264, 'radius': 1.5940560980447724, 'Mass': 2.684858889307982, 'zOffsetMax': 0.2811887726276586, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.23047000761577396, 0.7705602471914983, -0.943162434288483], 'initial_vel': [-0.05104937546791828, -0.02184433774691144, -0.03144583601144606], 'clockwise': 0, 'angle': [-1.0242559873923467, -4.709602966833427, 4.065339903368884]}
{'nbStars': 309, 'radius': 1.5165528020688763, 'Mass': 1.9633955055788501, 'zOffsetMax': 0.0610990064188866, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.5682759769762264, -0.4575998852794805, 0.820185168892819], 'initial_vel': [0.03878727636249196, -0.08604721455355285, 0.011188717339797427], 'clockwise': 1, 'angle': [-5.760662250423902, -1.656681198644978, -2.2596999718882387]}


100%|██████████| 1000/1000 [00:21<00:00, 45.78it/s]


{'nbStars': 193, 'radius': 1.6830678180378378, 'Mass': 2.653058360094854, 'zOffsetMax': 0.32636895235366886, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.4004598550112697, 0.09585330475528453, 0.5750707471436365], 'initial_vel': [-0.0419106528483135, -0.026180505467624224, -0.04930450782061644], 'clockwise': 0, 'angle': [-0.43153664951834875, 4.346268453256508, 0.5759641418587439]}
{'nbStars': 472, 'radius': 1.242893955359665, 'Mass': 2.333952892062291, 'zOffsetMax': 0.40352035382798673, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.2549066616398894, 0.5835588604270572, 0.8952307081599709], 'initial_vel': [-0.09489294091724697, -0.008586371625730657, 0.07205146543017446], 'clockwise': 1, 'angle': [5.00268835489883, 5.267847610155252, 3.673586985246653]}


100%|██████████| 1000/1000 [00:28<00:00, 35.19it/s]


{'nbStars': 412, 'radius': 1.6409714429262143, 'Mass': 1.942475056201624, 'zOffsetMax': 0.13598274881388572, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.8383654135187555, -0.9945299438986537, -0.2613944645979196], 'initial_vel': [-0.03459355426883126, -0.024881540134430224, -0.03933728439330897], 'clockwise': 0, 'angle': [-2.2870025477132954, -3.1240368524923405, -2.766150894575088]}
{'nbStars': 355, 'radius': 1.518288519485259, 'Mass': 2.5362094060385054, 'zOffsetMax': 0.38432232261628096, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.29900783049380597, -0.29054089876016853, -0.31952594797141187], 'initial_vel': [0.08083695776057961, 0.0032092676298167977, 0.07536389961385709], 'clockwise': 1, 'angle': [1.0878222602483447, -5.767895407868039, 0.39449040673611163]}


100%|██████████| 1000/1000 [00:36<00:00, 27.46it/s]


{'nbStars': 485, 'radius': 1.5978918479697906, 'Mass': 2.4186531848519697, 'zOffsetMax': 0.14354347854795713, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.8704583936004366, 0.2782049080318243, -0.5288129135203414], 'initial_vel': [-0.0025367231723941452, 0.04534392368643125, -0.08854574987963577], 'clockwise': 0, 'angle': [2.5451878957378264, -5.8437479746246614, -3.424179017828404]}
{'nbStars': 179, 'radius': 1.2135013044476064, 'Mass': 2.360101025654586, 'zOffsetMax': 0.23106677181086155, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [0.0899492957877599, -0.9989940307923104, 0.6023250716981043], 'initial_vel': [0.07532559241629933, 0.030793823902994816, -0.018289031064992245], 'clockwise': 1, 'angle': [-5.577298418002743, 1.8413535535833079, 4.468195921941863]}


 61%|██████    | 609/1000 [00:17<00:11, 33.96it/s]


KeyboardInterrupt: 

In [3]:
# Validation set: a single scene, kept in frame order (no shuffling), saved under './val/'.
dataset = generate_dataset(n_scenes=1, window_size=3, shuffle=False, dir='./val/', save=True)

{'nbStars': 191, 'radius': 1.510812342310428, 'Mass': 1.4600278559427262, 'zOffsetMax': 0.2931003589591229, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.4457418074649573, -0.10370057910820907, 0.29451858292172184], 'initial_vel': [0.004420953238919931, 0.0716105615406098, 0.011004092400341975], 'clockwise': 1, 'angle': [1.191850233287116, -0.6165623260710122, -2.6473726488142857]}
{'nbStars': 395, 'radius': 1.4358907785954926, 'Mass': 1.292527270351111, 'zOffsetMax': 0.08924899287798277, 'gravityCst': 1.0, 'distribution': 'hernquist', 'offset': [-0.19983954058510012, -0.46778475214981796, 0.3711069394127089], 'initial_vel': [0.07767885688782913, 0.0034018459805066348, 0.08476918798418709], 'clockwise': 1, 'angle': [2.797476247430128, -3.9145259860912374, -4.9079990991035665]}


100%|██████████| 1000/1000 [00:23<00:00, 41.94it/s]
