In [1]:
import sys
import os
from glob import glob

import matplotlib.pyplot as plt
import pandas as pd

# Avoids warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
import tensorflow as tf

import ScalableLib.classifier.Multiband as multiband


2024-09-05 08:23:47.732038: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-05 08:23:47.732071: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-05 08:23:47.733135: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Classifier-only mode should ignore the physical parameters listed in the input.

In [2]:
# Select one GPU and check that TensorFlow recognises it.
# `device` is the index into the list of physical GPUs reported by TensorFlow.
device = 1
devices = tf.config.list_physical_devices('GPU')
if not -len(devices) <= device < len(devices):
    # Fail with an actionable message instead of a bare "list index out of range"
    # when the machine exposes fewer GPUs than the hard-coded index expects.
    raise IndexError(
        "GPU index {} requested, but TensorFlow reports {} GPU(s)".format(device, len(devices))
    )

# Make only the selected GPU visible to TensorFlow and enable memory growth on it.
tf.config.set_visible_devices(devices[device], 'GPU')
tf.config.experimental.set_memory_growth(device=devices[device], enable=True)

# Report which card was picked so the run log records the hardware used.
device_name = tf.config.experimental.get_device_details(devices[device])['device_name']
print("Using {}".format(device_name))

Using NVIDIA GeForce RTX 2080 Ti


Find the different folds and train a model using the stored data.

In [3]:
# Locate every fold directory produced by the record-creation step for this survey.
survey = 'PanStarrs'
path = os.path.join('../../02_CreateRecords/', survey, 'Folds/Fold_*')

# Plain string ordering: fine for single-digit fold numbers
# (a 'Fold_10' would sort before 'Fold_2').
folds = sorted(glob(path))
folds

['../../02_CreateRecords/PanStarrs/Folds/Fold_1',
 '../../02_CreateRecords/PanStarrs/Folds/Fold_2',
 '../../02_CreateRecords/PanStarrs/Folds/Fold_3',
 '../../02_CreateRecords/PanStarrs/Folds/Fold_4',
 '../../02_CreateRecords/PanStarrs/Folds/Fold_5',
 '../../02_CreateRecords/PanStarrs/Folds/Fold_6',
 '../../02_CreateRecords/PanStarrs/Folds/Fold_7']

Create the results folder.

In [4]:
# Create the top-level output folder. `exist_ok=True` makes the cell safe to
# re-run and avoids the check-then-create race of `os.path.exists` + `os.mkdir`.
os.makedirs('./Results', exist_ok=True)

Define the arguments for all the models.

In [5]:
# Early-stopping callback settings, referenced by `train_args_specific` below.
callbacks_args = {
    'patience': 20,
    'mode': 'max',
    'restore_best_weights': True,
    'min_delta': 0.001,
}

# Per-output loss weights, referenced by `train_args_specific` below.
# Alternative weighting (disabled):
# loss_weights = {'Class':300.0, 'T_eff':1.0,'Radius':1e0}
loss_weights = {
    'Class': 300.0,
    'T_eff': 20.0,
    'Radius': 1.0,
}

# Settings shared by every model.
train_args = {
    'hidden_size_bands': [64, 64, 64],
    'hidden_size_central': [64, 64, 64],
    'fc_layers_bands': [128, 128],
    'fc_layers_central': [128, 128],  # Neurons of each layer
    'regression_size': [128, 128],  # Each element is a layer with that size
    'buffer_size': 10000,
    'epochs': 1000,
    'num_threads': 7,
    'batch_size': 512,
    'dropout': 0.35,
    'lr': [[5e-3] * 5, 2.5e-3],  # [[band1, band2], central]
    'val_steps': 50,
    'max_to_keep': 0,  # Not used
    'steps_wait': 500,
    'use_class_weights': False,  # Not used
    'mode': 'classifier+regression',
}

# Settings specific to this experiment; the status notes are the author's own.
train_args_specific = {
    'phys_params': ['T_eff', 'Radius'],
    'use_output_bands': True,  # Working
    'use_output_central': False,  # Not used
    'use_common_layers': False,  # NOT working
    'bidirectional_central': False,  # Working
    'bidirectional_band': False,  # Not working
    'layer_norm_params': None,  # Used to normalize common layers
    'use_gated_common': False,  # Working
    'l1': 0.0,
    'l2': 0.0,
    'N_skip': 2,  # Cannot be greater than the number of timesteps
    'use_raw_input_central': False,
    'train_steps_central': 1,
    'print_report': True,
    'loss_weights_central': loss_weights,
    'callbacks_args': callbacks_args,
}



In [6]:
# Train one model per fold, writing each fold's results under ./Results/<fold name>.
for fold in folds:
    # Reset Keras global state so the previous fold's model does not linger.
    tf.keras.backend.clear_session()
    # Set the fold path
    base_dir = fold+'/'

    # Set the save path for this fold, e.g. './Results/Fold_1'. It is built from
    # the fold's directory name rather than by string-replacing the input prefix,
    # so it does not depend on the exact form of the glob pattern or on '/'
    # separators. makedirs also creates './Results' if it is missing and is a
    # no-op when the folder already exists.
    path_results_fold = os.path.join('./Results', os.path.basename(os.path.normpath(fold)))
    os.makedirs(path_results_fold, exist_ok=True)

    # Per-fold entries: output location and the preprocessing artefacts of this fold.
    train_args_specific['save_dir'] = path_results_fold
    train_args_specific['metadata_pre_path'] = base_dir+'metadata_preprocess.json'
    train_args_specific['path_scalers'] = os.path.join(fold, 'scalers')
    # Define the train args. The experiment-specific entries override the shared
    # ones. NOTE(review): this rebinds the notebook-level `train_args`, so after
    # the loop it holds the merged settings of the last fold.
    train_args = {**train_args, **train_args_specific}

    # Glob patterns for the TFRecord files of each split.
    train_files = base_dir+'train/*.tfrecord'
    val_files = base_dir+'val/*.tfrecord'
    test_files = base_dir+'test/*.tfrecord'

    # NOTE(review): presumably `train` configures the network and datasets and
    # `train_loop` runs the optimisation — confirm against ScalableLib.
    new = multiband.Network()
    new.train(train_args, train_files, val_files, test_files)
    new.train_loop()

./Results/Fold_1/Models/20240905-0824
Start training


I0000 00:00:1725539083.695806 1793112 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Early Stopping
              precision    recall  f1-score   support

  DSCT_SXPHE       0.78      0.60      0.68       381
     MIRA_SR       0.96      0.94      0.95       788
        RRAB       0.79      0.78      0.78      2000
         RRC       0.76      0.78      0.77      2000
         RRD       0.01      0.02      0.01        53
       T2CEP       0.36      0.24      0.29        38

    accuracy                           0.78      5260
   macro avg       0.61      0.56      0.58      5260
weighted avg       0.79      0.78      0.79      5260

{'R2': {'T_eff': 0.7485413397964977, 'Radius': 0.4359158641385813}, 'RMSE': {'T_eff': 557.95605, 'Radius': 25.940954}}
./Results/Fold_2/Models/20240905-0941
Start training
Early Stopping
              precision    recall  f1-score   support

  DSCT_SXPHE       0.78      0.60      0.68       381
     MIRA_SR       0.97      0.94      0.95       788
        RRAB       0.79      0.78      0.78      2000
         RRC       0.75      0.79     