In [None]:
import numpy as np
import os
import random
import torch
from PD.PlasmaDataset import PlasmaDataset
from PML.PlasmaModel import PlasmaModel
from PML.PMLParameters import PMLParameters

In [None]:
DATA_SPLIT = [0.5, 0.3, 0.2] #train/test/val split fractions
DATA_FRAC = 1 #fraction of files to load data from
HDF5_DATA_DIR = "./jtext_data/low_freq" #directory to source hdf5 files from
ORG_DATA_DIR = "./jtext_org" #directory for exporting data CSVs
MODEL_DIR = "./models" #directory to save trained models
MODEL_COUNT = 10 #number of models to randomly generate and train
FEATS = ['dx', 'dy'] #model features
HP_SEARCH = 'random' #hyperparameter search mode
SEED = 42 #seed for the global RNGs; set to None to leave them unseeded

#seed the global generators of every RNG library this notebook imports, so that
#anything drawing from them (e.g. the random hyperparameter search) gives the
#same results on a fresh-kernel re-run
if SEED is not None:
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

In [None]:
#locations of the processed CSVs that get handed to the modeler; for now these
#point at dummy dx/dy feature exports used for testing
PROCESSED_DATA = dict([
    #training split
    ("train_norm",   './jtext_org/train/train-norm-dxdy-dummy.csv'),
    ("train_labels", './jtext_org/train/train-labels-dxdy-dummy.csv'),
    #test split
    ("test_norm",    './jtext_org/test/test-norm-dxdy-dummy.csv'),
    ("test_labels",  './jtext_org/test/test-labels-dxdy-dummy.csv'),
    #validation split
    ("val_norm",     './jtext_org/val/val-norm-dxdy-dummy.csv'),
    ("val_labels",   './jtext_org/val/val-labels-dxdy-dummy.csv'),
])

In [None]:
#search space handed to the hyperparameter search (tunable ranges only)
#WARNING: grid search mode is currently broken (and inefficient anyway) - do not use it
#NOTE(review): how the nested [a, b] pairs are interpreted is decided by the PML
#code, not here - confirm there before changing the nesting
PARAMETER_RANGES = {
    #learning rate range
    'lr': [0.001, 0.01],
    #lstm layer count and hidden size ranges
    'lstm_layers': [
        [200, 400],
        [80, 120],
    ],
    #linear layer count and neuron ranges
    'linear_layers': [
        [100, 200],
        [100, 150],
    ],
    #dropout layer count and dropout probabilities
    'dropout_layers': [
        [0.05, 0.1],
    ],
}
#parameters held fixed for every model in the search
#NOTE(review): BCEWithLogitsLoss already applies a sigmoid internally and
#'output_activation' is also sigmoid - confirm the model does not apply the
#output activation before computing the loss (that would squash twice)
STATIC_PARAMETERS = {
    'batch_size': 8,
    #binary cross entropy loss (logits variant)
    'criterion': torch.nn.BCEWithLogitsLoss(),
    #number of training epochs per model
    'epochs': 10,
    #Xavier-normal initializer function
    'init': torch.nn.init.xavier_normal_,
    #input size is tied to the number of features
    'input_size': len(FEATS),
    #activation function for the LSTM layers
    'lstm_activation': torch.nn.functional.tanh,
    #activation function for the Linear layers
    'linear_activation': torch.nn.functional.relu,
    #ADAM optimizer class
    'optimizer': torch.optim.Adam,
    #activation for the output neuron
    'output_activation': torch.nn.functional.sigmoid,
}

In [None]:
def makeDataset(dataset:"PlasmaDataset", split:list, features:list, frac:float=1, preview=False, name:str="dxdy-dummy"):
    """Run the dataset pipeline end to end and export the result to CSV.

    Steps, in order: initialize the subdatasets, gather the hdf5 file info,
    source the requested feature data, compute statistics, normalize, export
    to CSV, optionally preview, and finally drop the in-memory datasets.

    Args:
        dataset: dataset handler to drive (only the methods called below are needed).
        split: train/test/val split, forwarded as ``data_split``.
        features: feature names to source from the files.
        frac: fraction of files to load data from, forwarded as ``data_frac``.
        preview: when truthy, call ``dataset.preview()`` after the export.
        name: tag forwarded to ``dataset.saveCSV``. Defaults to the previously
            hard-coded ``"dxdy-dummy"`` so existing callers behave the same.

    Returns:
        None. The function is used for its side effects on ``dataset``.
    """
    dataset.initialize() #creates train/test/val subdatasets
    dataset.sourceFiles(data_split = split, data_frac = frac) #initialize split/datafrac and gather hdf5 file info
    dataset.sourceData(features) #source specified feature data from files
    dataset.calcStats() #calculate data statistics from raw feature data
    dataset.normalize() #use data statistics to normalize data
    dataset.saveCSV(['train', 'test', 'val', 'stats'], name=name) #export dataset to model-loadable CSV
    if preview:
        dataset.preview() #preview datasets
    dataset.deleteDatasets() #remove dataset from memory (since saved to CSV)

In [None]:
def makeModels(
    modeler:"PlasmaModel",
    processed_data:dict,
    parameter_ranges:dict,
    static_parameters:dict,
    model_count:int,
    searchmode:str,
):
    """Generate a hyperparameter set, load the data and run the model search.

    Args:
        modeler: object driving the search (only the three methods called below are needed).
        processed_data: passed unchanged to ``modeler.prepareData``.
        parameter_ranges: forwarded as ``param_ranges``.
        static_parameters: forwarded as ``static_params``.
        model_count: forwarded as ``count``.
        searchmode: forwarded as ``mode``.

    Returns:
        None. All results live on ``modeler``.
    """
    #collect the search settings first so the call below stays short
    search_settings = {
        "static_params": static_parameters,
        "param_ranges": parameter_ranges,
        "count": model_count,
        "mode": searchmode,
    }
    modeler.makeHyperparameterSet(**search_settings)

    #data is prepared only after the hyperparameter set exists, then the search runs
    modeler.prepareData(processed_data)
    modeler.runModelSearch()

In [None]:
#build the dataset handler and the modeler from the configured directories
JTEXT_LOW = PlasmaDataset(h5_source=HDF5_DATA_DIR, org_directory=ORG_DATA_DIR)
MODELER = PlasmaModel(MODEL_DIR, static_parameters=STATIC_PARAMETERS)

#run the dataset pipeline (exports CSVs), then run the model search using the
#CSV paths listed in PROCESSED_DATA
makeDataset(JTEXT_LOW, DATA_SPLIT, FEATS, frac=DATA_FRAC)
makeModels(
    MODELER,
    PROCESSED_DATA,
    PARAMETER_RANGES,
    STATIC_PARAMETERS,
    MODEL_COUNT,
    HP_SEARCH,
)