In [None]:
import flowcode
import processing
import res_flow_vis as visual
import device_use

import torch
import numpy as np
import os
import time

In [None]:
#Run tag: used as the stem of every save/load path and as the plot label/figure name below
filename = "_".join(("NC", "CL2", "8", "8", "16", "4"))

In [None]:
#Initiate processing: `mp` is the shared processor object used below for loading (get_data),
#cleaning (constrain_data), conversion (Data_to_flow, sample_to_Data), sampling (sample_flow) and M_stars
mp = processing.Processor()

In [None]:
#Get raw data from a hard-coded relative path
Data = mp.get_data("Old/g8.26e11.npy")

In [None]:
#Clean data: Data_const is the subset of Data kept by mp.constrain_data
#(the next cell reports which percentage of entries was removed)
Data_const = mp.constrain_data(Data)

In [None]:
#Report the share of raw entries dropped by the cleaning step (raises ZeroDivisionError if Data is empty)
print("Excluded percentage: {:.1f}%".format(100 * (1 - len(Data_const) / len(Data))))

In [None]:
#Right device: chosen by the project's device_use module (presumably a torch.device or device string — TODO confirm);
#used for model.to(device) below
device = device_use.device_use

In [None]:
#Instantiate the neural-spline flow and move it to the selected device in one step.
#The positional arguments (8, 10, 0) are passed straight to flowcode.NSFlow; see its definition for their meaning.
#NOTE(review): the digits in `filename` ("8_8_16_4") appear to mirror these hyperparameters — keep them in sync.
model = flowcode.NSFlow(
    8, 10, 0, flowcode.NSF_CL2,
    K=8,
    B=3,
    network=flowcode.MLP,
    network_args=(16, 4, 0.2),
).to(device)

In [None]:
#Convert the constrained data into the representation fed to the flow (used for the export and for training below)
Data_flow = mp.Data_to_flow(Data_const)

In [None]:
#Save the training data, the model and the training settings to disk so that an external
#Python script (meant to run on a GPU) can do the training in the background.
#torch.save/np.save do not create missing directories, so make sure the target folder exists.
os.makedirs("NC_trainer", exist_ok=True)
torch.save(Data_flow, "NC_trainer/data_NC_trainer.pth")
#Pickles the whole module (not just a state_dict): loading it requires flowcode to be importable.
torch.save(model, "NC_trainer/model_NC_trainer.pth")
#Settings for the external trainer. NOTE(review): these appear to be [epochs, lr, batch_size, gamma]
#(cf. the in-notebook train_flow call), but they differ from the values used there (100, 0.016, 1024, 0.998) — confirm.
np.save("NC_trainer/params_NC_trainer.npy", np.array([400,0.0004,1024,0.9985]))
#Stored as a 0-d string array so the external script writes its results under the same run tag.
np.save("NC_trainer/filename_NC_trainer.npy", filename)

In [None]:
#Train the model inside the notebook (this can take a long time).
#Create the output folder up front: otherwise a missing "saves/" directory only makes
#torch.save/np.save fail *after* training has finished and the result is lost.
os.makedirs("saves", exist_ok=True)
#list to collect loss into (passed to train_flow as loss_saver)
my_loss_saver = []
start = time.perf_counter()
#NOTE(review): 100 presumably is the number of epochs; these settings differ from the ones exported
#to NC_trainer/params_NC_trainer.npy — confirm which run the files in saves/ belong to.
flowcode.train_flow(model, Data_flow, np.array([]), 100, lr=0.016, batch_size=1024, loss_saver=my_loss_saver, gamma=0.998)
end = time.perf_counter()
torch.save(model.state_dict(), f"saves/{filename}.pth")
#The total wall-clock training time (seconds) is appended as the last element of the loss array.
np.save(f"saves/loss_{filename}.npy",np.array(my_loss_saver+[end-start]))

In [None]:
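#Manual fallback (kept commented out): save the current weights and record a hand-entered training
#time of 358 minutes (stored in seconds) instead of a measured one — presumably for runs trained outside this notebook.
#torch.save(model.state_dict(), f"saves/{filename}.pth")
#np.save(f"saves/loss_{filename}.npy",np.array(my_loss_saver+[358*60]))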

In [None]:
#Load in training results.
#map_location remaps the saved tensors onto the currently selected device, so weights saved from a
#GPU run can also be restored on a CPU-only machine (without it torch.load fails there).
model.load_state_dict(torch.load(f"saves/{filename}.pth", map_location=device))
loss_results = np.load(f"saves/loss_{filename}.npy")
#The last entry is the total training time in seconds (appended when saving); split it off and convert to minutes.
loss_results, tot_time = loss_results[:-1], loss_results[-1]/60

In [17]:
#Sample model: draw as many samples as there are constrained data points so the plots compare equal-sized sets.
N_samples = Data_const.shape[0]
#Switch to inference mode. network_args contains 0.2, presumably a dropout rate — if so it must be off while sampling.
#eval() is harmless if sample_flow already does this. NOTE: call model.train() before any re-training.
model.eval()
start = time.perf_counter()
#Sampling needs no gradients; disabling autograd avoids building a graph over a very large batch.
#700000 is passed through to mp.sample_flow (presumably a per-batch sample count — TODO confirm).
with torch.no_grad():
    flow_sample = mp.sample_to_Data(mp.sample_flow(model, N_samples, 700000))
print(f"Sampling time: {time.perf_counter()-start:.2f}s")

Sampling time: 0.72s


In [None]:
#Get result plots comparing the constrained data with the flow samples; the run tag is used as the label.
#NOTE(review): the same mp.M_stars is passed for both entries of Mass — confirm this is intended.
visual.get_result_plots(Data_const, flow_sample, label=filename, N_unit="massperkpc", Mass=(mp.M_stars,mp.M_stars))

In [None]:
#Get loss curve; tot_time is the training time in minutes (from the loading cell),
#savefig=filename presumably sets the name of the saved figure
visual.loss_plot(loss_results, tot_time=tot_time, savefig=filename)

In [None]:
#Development helper: re-imports res_flow_vis so edits to that module are picked up without restarting the kernel.
#It does not affect any result above; re-run the plotting cells afterwards to use the reloaded code
#(consider %autoreload in the import cell instead).
import importlib
importlib.reload(visual)