# training

In [1]:
## ML imports
import tensorflow as tf
import keras
from keras import layers

from scripts import WMSE, InversePCA

## other imports
import pandas as pd
import numpy as np
import os

##### poke gpu
## Expose only physical GPU index 1 to this process.
## NOTE(review): this runs after `import tensorflow`; it relies on the GPU not
## having been initialised yet when the variable is set -- keep this cell first.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

physical_devices = tf.config.list_physical_devices("GPU")

if physical_devices:
    ## memory currently allocated by TensorFlow on the first visible device
    gpu0usage = tf.config.experimental.get_memory_info("GPU:0")["current"]
    print(f"Current GPU usage:\n - GPU0: {gpu0usage}B\n")
else:
    ## querying "GPU:0" would raise when no GPU is visible, so report it instead
    gpu0usage = None
    print("No GPU visible to TensorFlow (CUDA_VISIBLE_DEVICES="
          + os.environ["CUDA_VISIBLE_DEVICES"] + ").\n")

2026-01-23 11:04:04.891588: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-01-23 11:04:04.891616: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-01-23 11:04:04.892731: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-01-23 11:04:04.898235: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Current GPU usage:
 - GPU0: 0B



2026-01-23 11:04:06.084246: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 18447 MB memory:  -> device: 0, name: NVIDIA RTX A4500, pci bus id: 0000:61:00.0, compute capability: 8.6


## load in data

In [2]:
## root of the training repo; kept as a str ending in '/' because paths
## are built from it by plain concatenation
repo_dir = '/home/oxs235/datastorage/repos_data/ojscutt/fitchpork-train/'
data_dir = f'{repo_dir}data/'

## training and validation tables, both stored under the HDF5 key 'df'
df_train = pd.read_hdf(f'{data_dir}bob_train_slim_fp.h5', key='df')
df_val = pd.read_hdf(f'{data_dir}bob_val_slim_fp.h5', key='df')

## PCA mean and components, read as plain numeric text
pca_mean = np.loadtxt(f'{data_dir}pca_mean_fp.csv')
pca_components = np.loadtxt(f'{data_dir}pca_components_fp.csv')

## WMSE weights, converted from an array to a plain Python list
WMSE_weights = np.loadtxt(f'{data_dir}WMSE_weights_fp.csv').tolist()

## define inputs and outputs

In [3]:
#### input columns fed to the network
inputs = [
    'log_initial_mass_std',
    'log_initial_Zinit_std',
    'log_initial_Yinit_std',
    'log_initial_MLT_std',
    'log_star_age_std',
]

#### output columns: three classical observables ...
classical_outputs = [
    'log_radius_std',
    'log_luminosity_std',
    'star_feh_std',
]
#### ... followed by the log_nu_0_<n> columns for n = 6 .. 40 inclusive
astero_outputs = [f'log_nu_0_{n}_std' for n in range(6, 41)]

#### full output ordering: classical first, then asteroseismic
outputs = [*classical_outputs, *astero_outputs]

# train pitchfork!

In [None]:
## name used for the Keras model, its checkpoint file and its TensorBoard run
model_name = 'fitchpork-duo'

## output locations under the repo root
tb_dir = f'{repo_dir}logs/fit/'
model_dir = f'{repo_dir}models/'

## architecture variables: (number of dense layers, units per layer)
stem_d_layers, stem_d_units = 2, 128       # shared stem
ctine_d_layers, ctine_d_units = 2, 64      # classical tine
atine_d_layers, atine_d_units = 6, 128     # astero tine

## learning rate handed to the optimizer at the start of training
initial_lr = 0.001


def _dense_stack(x, n_layers, n_units):
    """Chain `n_layers` ELU Dense layers of width `n_units` onto tensor `x`.

    Returns `x` unchanged when `n_layers` is 0, so a section of the network
    can be switched off from the architecture variables without a NameError.
    """
    for _ in range(n_layers):
        x = layers.Dense(n_units, activation='elu')(x)
    return x


## drop any graph state left over from a previous run of this cell
tf.keras.backend.clear_session()

######## stem
#### input: one feature per entry of `inputs`
stem_input = keras.Input(shape=(len(inputs),))

#### dense layers shared by both tines
stem = _dense_stack(stem_input, stem_d_layers, stem_d_units)

######## classical tine
#### dense layers
ctine = _dense_stack(stem, ctine_d_layers, ctine_d_units)

#### output: one linear unit per classical output column
ctine_out = layers.Dense(len(classical_outputs), name='classical_outs')(ctine)


######## astero tine
#### dense layers
atine = _dense_stack(stem, atine_d_layers, atine_d_units)

#### output: linear layer with len(pca_components) units, then the project's
#### InversePCA layer built from the loaded PCA components and mean
#### (presumably maps PCA coefficients back to the astero outputs -- see scripts)
atine = layers.Dense(len(pca_components))(atine)
atine_out = InversePCA(pca_comps = pca_components, pca_mean = pca_mean, name='asteroseismic_outs')(atine)

######## construct
#### two-headed model: outputs are ordered [classical, asteroseismic]
model = keras.Model(inputs=stem_input, outputs=[ctine_out, atine_out], name=model_name)

#### compile model
optimizer = tf.keras.optimizers.Adam(learning_rate=initial_lr)

## One WMSE loss per model output, in the same order as the model's outputs
## list. The weight vector is split at the number of classical outputs rather
## than at a hard-coded index, so the split follows `classical_outputs`.
## NOTE(review): assumes WMSE_weights is ordered like `outputs`
## (classical first, then astero) -- confirm against how the file was written.
n_classical = len(classical_outputs)

model.compile(loss=[WMSE(WMSE_weights[:n_classical]), WMSE(WMSE_weights[n_classical:])],
              optimizer=optimizer)

#### fit model
## learning-rate schedule constants
LR_FLOOR = 1e-5               # once lr is below this it is left unchanged
LR_DECAY_PER_EPOCH = 0.00006  # lr is multiplied by exp(-LR_DECAY_PER_EPOCH) each epoch

def scheduler(epoch, lr):
    """Exponentially decay the learning rate until it drops below LR_FLOOR.

    Parameters
    ----------
    epoch : int
        Epoch index supplied by the callback; not used by this schedule.
    lr : float
        Current learning rate.

    Returns
    -------
    float
        `lr` unchanged if it is already below LR_FLOOR, otherwise
        `lr * exp(-LR_DECAY_PER_EPOCH)`. A plain Python float is returned
        (not a tf.Tensor), matching the documented LearningRateScheduler
        contract.
    """
    if lr < LR_FLOOR:
        return lr
    return float(lr * np.exp(-LR_DECAY_PER_EPOCH))

## apply `scheduler` to the learning rate at the start of every epoch
lr_callback = tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=0)

## HDF5 checkpoint, evaluated once per epoch and only overwritten when
## val_loss improves
checkpoint_path = model_dir + model_name + ".h5"
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_loss',
    save_best_only=True,
    save_freq='epoch',
)

## TensorBoard logs go to a per-model sub-directory
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=tb_dir+model_name)

## features and two-part targets, ordered [classical, astero] like the model outputs
train_x = df_train[inputs]
train_y = [df_train[classical_outputs], df_train[astero_outputs]]
val_x = df_val[inputs]
val_y = [df_val[classical_outputs], df_val[astero_outputs]]

history = model.fit(
    train_x,
    train_y,
    validation_data=(val_x, val_y),
    batch_size=32768,
    verbose=1,
    epochs=100000,
    callbacks=[lr_callback, cp_callback, tb_callback],
    shuffle=True,
)

2026-01-23 11:04:12.276536: I external/local_tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory


Epoch 1/100000


2026-01-23 11:04:13.924875: I external/local_xla/xla/service/service.cc:168] XLA service 0x7ef0918e3a10 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2026-01-23 11:04:13.924914: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA RTX A4500, Compute Capability 8.6
2026-01-23 11:04:13.929886: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2026-01-23 11:04:13.943007: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8904
I0000 00:00:1769166254.023057 1266698 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Epoch 2/100000
15/68 [=====>........................] - ETA: 0s - loss: 5314307.5000 - classical_outs_loss: 305.5121 - inverse_pca_loss: 5314002.0000

  saving_api.save_model(


Epoch 3/100000
Epoch 4/100000
Epoch 5/100000
Epoch 6/100000
Epoch 7/100000
Epoch 8/100000
Epoch 9/100000
Epoch 10/100000
Epoch 11/100000
Epoch 12/100000
Epoch 13/100000
Epoch 14/100000
Epoch 15/100000
Epoch 16/100000
Epoch 17/100000
Epoch 18/100000
Epoch 19/100000
Epoch 20/100000
Epoch 21/100000
Epoch 22/100000
Epoch 23/100000
Epoch 24/100000
Epoch 25/100000
Epoch 26/100000
Epoch 27/100000
Epoch 28/100000
Epoch 29/100000
Epoch 30/100000
Epoch 31/100000
Epoch 32/100000
Epoch 33/100000
Epoch 34/100000
Epoch 35/100000
Epoch 36/100000
Epoch 37/100000
Epoch 38/100000
Epoch 39/100000
Epoch 40/100000
Epoch 41/100000
Epoch 42/100000
Epoch 43/100000
Epoch 44/100000
Epoch 45/100000
Epoch 46/100000
Epoch 47/100000
Epoch 48/100000
Epoch 49/100000
Epoch 50/100000
Epoch 51/100000
Epoch 52/100000
Epoch 53/100000
Epoch 54/100000
Epoch 55/100000
Epoch 56/100000
Epoch 57/100000
Epoch 58/100000
Epoch 59/100000
Epoch 60/100000
Epoch 61/100000
Epoch 62/100000
Epoch 63/100000
Epoch 64/100000
Epoch 65/100000