In [None]:
import flowcode
import processing
import res_flow_vis as visual
import device_use
import externalize as ext

import torch
import numpy as np

In [None]:
# Run identifier: stem of the weight/loss files written to saves/ and the label on every plot below.
# Its suffix appears to mirror the settings further down (NSF_CL2, 24 layers, K=10, 512 hidden,
# 8 MLP layers, leave-out index 15); keep it in sync by hand when changing those.
filename = "leavout_MttZ_MWs_CL2_24_10_512_8_lo15"  # previous run: "leavout_MttZ_MWs_CL2_24_10_512_8_lo5"

In [None]:
# Processor used throughout the notebook: it loads, cleans, subsets and stacks the data and
# converts flow samples back into per-galaxy data. The exact meaning of N_min / percentile2
# lives in processing.Processor_cond (not visible here).
mpc = processing.Processor_cond(
    N_min=500,
    percentile2=95,
)

In [None]:
# Load the raw "all_sims" data set; the names suggest per-galaxy data plus star counts,
# stellar masses and dark-matter masses (confirm in Processor_cond.get_data).
(Data, N_stars, M_stars, M_dm) = mpc.get_data("all_sims")

In [None]:
# Apply the processor's cleaning constraints, producing *_const counterparts of the raw arrays.
# NOTE(review): only Data and M_dm are passed in, yet four values come back; presumably the
# star counts/masses are recomputed from the cleaned Data. Confirm in Processor_cond.constraindata.
(Data_const, N_stars_const, M_stars_const, M_dm_const) = mpc.constraindata(Data, M_dm)

In [None]:
# Choose subsets of the cleaned data.
# 1) Comparison ("view") subset: galaxies selected by ext.MW_like_galaxy. Its condition columns
#    drive the sampler and it is the reference in the plots below.
Data_sub_v, N_stars_sub_v, M_stars_sub_v, M_dm_sub_v = mpc.choose_subset(
    Data_const, N_stars_const, M_stars_const, M_dm_const,
    use_fn=ext.MW_like_galaxy,
    cond_fn=ext.cond_M_stars_2age_avZ,
)

# 2) Training subset (leave one out): the excluded galaxy is identified by its dark-matter mass
#    at position `leavout_idices` of the comparison subset. Update `filename` when changing it.
leavout_idices = 15
leavout_fn = ext.construct_MW_like_galaxy_leavout(M_dm_sub_v[leavout_idices])
Data_sub, N_stars_sub, M_stars_sub, M_dm_sub = mpc.choose_subset(
    Data_const, N_stars_const, M_stars_const, M_dm_const,
    use_fn=leavout_fn,
    cond_fn=ext.cond_M_stars_2age_avZ,
)

In [None]:
# Device the model is placed on in this notebook. The sampler below is handed its own GPU ids
# (use_GPUs), independently of this setting.
device = "cpu"  # alternative: device_use.device_use

In [None]:
# Architecture hyperparameters of the normalizing flow.
LAYER_TYPE = flowcode.NSF_CL2
N_LAYERS = 24

# Columns of the stacked data used as conditions; the remaining columns are the modelled ones.
COND_INDS = np.array([10, 11, 12, 13])
DIM_COND = len(COND_INDS)
DIM_NOTCOND = Data_sub[0].shape[1] - DIM_COND

SPLIT = 0.5
K = 10  # forwarded to NSFlow as K= (presumably number of spline bins; confirm in flowcode)
B = 3   # forwarded to NSFlow as B= (presumably spline bound; confirm in flowcode)

# Network inside each layer; NSFlow receives network_args=(N_HIDDEN, N_LAYERS, LEAKY_RELU_SLOPE).
BASE_NETWORK = flowcode.MLP
BASE_NETWORK_N_LAYERS = 8
BASE_NETWORK_N_HIDDEN = 512
BASE_NETWORK_LEAKY_RELU_SLOPE = 0.2

# `split` is forwarded to NSFlow only for the NSF_CL layer type. SPLIT is rebound here from the
# float above to the keyword-argument dict that the model cell unpacks with **SPLIT.
if LAYER_TYPE == flowcode.NSF_CL:
    SPLIT = {"split": SPLIT}
else:
    SPLIT = {}

In [None]:
# Build the conditional flow from the hyperparameters above and move it onto `device`.
model = flowcode.NSFlow(
    N_LAYERS,
    DIM_NOTCOND,
    DIM_COND,
    LAYER_TYPE,
    **SPLIT,
    K=K,
    B=B,
    network=BASE_NETWORK,
    network_args=(BASE_NETWORK_N_HIDDEN, BASE_NETWORK_N_LAYERS, BASE_NETWORK_LEAKY_RELU_SLOPE),
).to(device)
# To start from previously trained weights instead, uncomment:
# model.load_state_dict(torch.load("saves/leavout_M_star_MWs_CL2_24_10_512_8_lo1.pth"))

In [None]:
# Optimisation hyperparameters (passed to train_flow / the external trainer).
N_EPOCHS = 12
INIT_LR = 0.00009
GAMMA = 0.998  # passed as gamma= (presumably learning-rate decay factor)
BATCH_SIZE = 1024

# Column indices for preprocessing.
LOG_LEARN = np.array([10])          # columns learned in log10 space
SMOOTHEN_MAX = np.array([7, 8, 9])  # upper tanh smoothing targets (currently disabled)
SMOOTHEN_MIN = np.array([6, 9])     # lower tanh smoothing targets (currently disabled)

# tanh smoothing transforms; only needed if smoothing is re-enabled below.
max_s = ext.tanh_smooth("max")
min_s = ext.tanh_smooth("min")

# Forward transforms, the column sets they act on, and their inverses, presumably paired by
# position. To re-enable smoothing, append max_s.smooth, min_s.smooth / SMOOTHEN_MAX,
# SMOOTHEN_MIN / max_s.smooth_inv, min_s.smooth_inv to the three tuples respectively.
transformations = (np.log10,)
trf_indices = (LOG_LEARN,)
transformations_inv = (lambda x: 10**x,)

In [None]:
# Stack the training galaxies into one array and apply the forward transforms (inverses are
# passed along too) to obtain the flow's training data.
Data_flow = mpc.Data_to_flow(
    mpc.diststack(Data_sub),
    transformations,
    trf_indices,
    transformations_inv,
)

In [None]:
# Save everything an external trainer (cond_trainer.py) needs, so training can run in the
# background; that script is expected to run on a GPU device.
import os

# torch.save / np.save do not create missing directories; make sure the target exists.
os.makedirs("cond_trainer", exist_ok=True)
torch.save(Data_flow, "cond_trainer/data_cond_trainer.pth")
torch.save(model, "cond_trainer/model_cond_trainer.pth")
# Flat parameter vector: the condition column indices followed by
# [N_EPOCHS, INIT_LR, BATCH_SIZE, GAMMA]. np.append upcasts everything to float64, so a reader
# has to cast the indices, epoch count and batch size back to int.
np.save("cond_trainer/params_cond_trainer.npy", np.append(COND_INDS, np.array([N_EPOCHS, INIT_LR, BATCH_SIZE, GAMMA])))
np.save("cond_trainer/filename_cond_trainer.npy", filename)

In [None]:
# Start background training with cond_trainer.py (presumably reading the files saved above).
# - sys.executable: run with the kernel's own interpreter, not whichever `python3` is first on
#   PATH, so the trainer sees the same environment and packages.
# - Popen already returns immediately, so no shell and no trailing `&` are needed.
# - nohup + start_new_session detach the trainer from the kernel's session/process group, so it
#   keeps running if the kernel is interrupted or restarted.
import subprocess
import sys

trainer_process = subprocess.Popen(["nohup", sys.executable, "cond_trainer.py"], start_new_session=True)
trainer_process  # keep the handle, e.g. trainer_process.poll() to check whether it finished

In [None]:
# ...OR train in this kernel instead of in the background (run only one of the two cells).
import os
import time

# Create the output directory *before* training, so the saves at the end cannot fail on a
# missing directory after a long run.
os.makedirs("saves", exist_ok=True)
train_loss_saver = []  # train_flow records the loss history here via loss_saver=
start = time.perf_counter()
flowcode.train_flow(model, Data_flow, COND_INDS, N_EPOCHS, lr=INIT_LR, batch_size=BATCH_SIZE, loss_saver=train_loss_saver, gamma=GAMMA)
end = time.perf_counter()
torch.save(model.state_dict(), f"saves/{filename}.pth")
# Loss history with the total training wall time in seconds appended as the final entry.
np.save(f"saves/loss_{filename}.npy", np.array(train_loss_saver + [end - start]))

In [None]:
# Load the trained weights (onto `device`) and the saved loss history.
model.load_state_dict(torch.load(f"saves/{filename}.pth", map_location=device))
loss_results = np.load(f"saves/loss_{filename}.npy")
# The last entry is the training wall time in seconds; convert it to minutes before dropping it.
tot_time = loss_results[-1] / 60
loss_results = loss_results[:-1]

In [None]:
# Sample from the flow, conditioned on the condition columns of the comparison galaxies.
import time

use_GPUs = [1, 2, 3, 4, 6, 7, 8]  # GPU ids handed to sample_Conditional (independent of `device`)
start = time.perf_counter()
condition = mpc.diststack(Data_sub_v)[:, COND_INDS]

# Sample (split_size = chunk size presumably), map back to data space, and split into
# per-galaxy arrays using the comparison subset's star counts.
flow_sample = mpc.galaxysplit(
    mpc.sample_to_Data(
        mpc.sample_Conditional(model, COND_INDS, condition, split_size=int(6e5), GPUs=use_GPUs)
    ),
    N_stars_sub_v,
)

# Read the clock once so the minutes and seconds describe the same elapsed time.
minutes, seconds = divmod(time.perf_counter() - start, 60)
print(f"Time to sample: {int(minutes)} minutes and {int(seconds)} seconds")

In [None]:
### Visualize data

In [None]:
# Overview plot of all comparison galaxies: simulation data vs. flow sample, both paired with
# the comparison subset's stellar masses.
visual.plot_conditional_2(
    Data_sub_v, M_stars_sub_v,
    flow_sample, M_stars_sub_v,
    type="N",
    label=filename,
    N_unit="massperkpc",
    color_pass="first",
    global_grid=True,
)

In [None]:
# Detailed comparison for the galaxy left out of training. Use the same index that defined
# the leave-out subset instead of a hard-coded copy, so the two cannot drift apart.
visual.get_result_plots(Data_sub_v[leavout_idices], flow_sample[leavout_idices], label=filename, format_="pdf")

In [None]:
# Histograms of the flow sample. flow_sample is split per comparison galaxy (N_stars_sub_v), so it
# must be paired with the comparison subset's masses; the training-subset M_stars_sub lacks the
# left-out galaxy and would be misaligned.
visual.plot_conditional_histograms(flow_sample, M_stars_sub_v, label=filename, log=True)

In [None]:
# Training-loss curve, annotated with the total training time in minutes (savefig uses the run name).
visual.loss_plot(
    loss_results,
    tot_time=tot_time,
    savefig=filename,
)

In [None]:
# Development helper: re-import res_flow_vis after editing it, so subsequent plotting calls use
# the new code without restarting the kernel. Plots already drawn are not affected.
import importlib
importlib.reload(visual)