In [1]:
%matplotlib notebook
import os

import numpy as np
import matplotlib.pyplot as plt

from scipy.signal import spectrogram, stft, istft, check_NOLA
import lightgbm as lgb
from lightgbm import LGBMRegressor

import ray
# Start a local Ray instance limited to 8 CPUs, with the web dashboard enabled.
# NOTE(review): dashboard_host='0.0.0.0' exposes the dashboard on every network
# interface. Ray's dashboard is not authenticated by default (confirm for your
# Ray version); on a shared cluster prefer '127.0.0.1' plus an SSH tunnel.
ray.init(num_cpus=8,
         include_dashboard=True,
         dashboard_host='0.0.0.0')

plt.style.use('ggplot')

2023-03-26 14:28:47,296	INFO worker.py:1553 -- Started a local Ray instance.


# UTILITY FUNCTIONS

In [2]:
subsampleFreq = 64   # FINAL FREQUENCY IN HERTZ AFTER SUBSAMPLING
secondsInWindow = 1.
# STFT segment length in samples. Kept an int: the raw product is a float
# (64 * 1. == 64.0), which would also make noverlap a float.
nperseg = int(round(subsampleFreq * secondsInWindow))
noverlap = nperseg - 1   # maximal overlap: consecutive segments are 1 sample apart
window = ('tukey', .25)

# istft can only invert the STFT when the window/overlap satisfy the
# nonzero-overlap-add (NOLA) constraint; fail fast if a tweak above breaks it.
assert check_NOLA(window, nperseg, noverlap), 'STFT settings violate NOLA; istft cannot invert'

In [3]:
# CONVERT STFT FROM POLAR (R, THETA) FORM TO COMPLEX
# dim(z) = (# timesteps, 2 x # freq bins): magnitudes r in the first half of the columns, phases theta in the second half

def rThetaToComplex(z):
    """Convert a polar (r, theta) STFT matrix into a complex STFT matrix.

    Parameters
    ----------
    z : array_like, shape (n_timesteps, n_cols)
        Row i is one time step. Columns ``0 .. n_cols//2 - 1`` hold the
        magnitudes r and columns ``n_cols//2 .. 2*(n_cols//2) - 1`` hold the
        matching phases theta (radians). If ``n_cols`` is odd, the last column
        is ignored.

    Returns
    -------
    numpy.ndarray of complex64, shape (n_cols // 2, n_timesteps)
        ``r * exp(1j * theta)``, transposed so frequency bins run along axis 0
        and time along axis 1.
    """
    z = np.asarray(z, dtype=np.float64)
    _, cols = z.shape          # also enforces a 2-D input, as the loop version did
    half = cols // 2
    r = z[:, :half]
    theta = z[:, half:2 * half]
    # Vectorized replacement for the former per-element Python loop. It is
    # evaluated in double precision (as the loop did via Python complex) and
    # then stored as single-precision complex.
    shortTermFourier = (r * np.exp(1j * theta)).astype(np.csingle)
    return shortTermFourier.transpose()  # dim = (# freq bins, # timepoints)

# CONVERT REAL (R, THETA) STFT TO COMPLEX STFT, THEN INVERT WITH ISTFT TO GET THE TIME SERIES; RETURNS (times, signal)

def realSTFTtoTimeSeries(realSTFT):
    """Reconstruct a time series from a real-valued (r, theta) STFT matrix.

    The matrix is converted to a complex STFT with ``rThetaToComplex``. That
    STFT is then inverted with ``istft``, using the module-level settings
    ``subsampleFreq``, ``window``, ``nperseg`` and ``noverlap`` (read at call
    time).

    Parameters
    ----------
    realSTFT : array_like, shape (n_timesteps, 2 * n_freq_bins)
        Magnitudes in the first half of the columns, phases in the second half.

    Returns
    -------
    tuple (times, signal)
        The two arrays returned by ``istft``.
    """
    complexSTFT = rThetaToComplex(realSTFT)
    istftSettings = {'fs': subsampleFreq,
                     'window': window,
                     'nperseg': nperseg,
                     'noverlap': noverlap}
    sampleTimes, signal = istft(complexSTFT, **istftSettings)
    return sampleTimes, signal

# LOAD NUMPY ARRAYS

In [None]:
arraySavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/trainTestRTheta.npz'

# Use the NpzFile as a context manager so its file handle is closed once the
# four arrays have been read into memory (indexing an NpzFile loads the array).
with np.load(arraySavePath) as npzfile:
    x_trainRTheta = npzfile['x_trainRTheta']
    x_validRTheta = npzfile['x_validRTheta']
    y_trainRTheta = npzfile['y_trainRTheta']
    y_validRTheta = npzfile['y_validRTheta']

# Number of target columns; one LightGBM model is trained per column below.
_, nY = y_validRTheta.shape

# The per-column models are trained on the train split and evaluated on the
# valid split, so both splits must agree on feature/target columns and each
# x must have one row per label.
assert x_trainRTheta.shape[1] == x_validRTheta.shape[1], 'train/valid feature columns differ'
assert y_trainRTheta.shape[1] == nY, 'train/valid target columns differ'
assert len(x_trainRTheta) == len(y_trainRTheta), 'train x/y row counts differ'
assert len(x_validRTheta) == len(y_validRTheta), 'valid x/y row counts differ'

# LIGHTGBM

In [None]:
# NATIVE LIGHTGBM API VERSION

# Directory holding one saved LightGBM model per target column. File paths are
# built by plain string concatenation, so the trailing '/' is required.
modelDirectory = ('/blue/gkalamangalam/jmark.ettinger/predictScalp/'
                  'lgbModels/')

@ray.remote
def train_index(x_trainRTheta, y_trainRTheta, x_validRTheta, y_validRTheta, index, init_model, num_iterations):
    """Train (or continue training) the LightGBM model for one target column.

    Runs as a Ray task. The booster is fit on ``y_trainRTheta[:, index]``,
    with MSE on the validation split driving early stopping (5 rounds). It is
    saved to ``modelDirectory + 'lgbModel_<index>.txt'``, truncated at
    ``bst.best_iteration``.

    Parameters
    ----------
    x_trainRTheta, y_trainRTheta, x_validRTheta, y_validRTheta : array_like
        Train/validation features and targets (the targets are indexed by column).
    index : int
        Target column to model; also used in the saved file name.
    init_model : str or None
        Path of a saved model to continue boosting from, or None to start fresh.
    num_iterations : int
        Maximum number of boosting rounds for this call.

    Returns
    -------
    lightgbm.Booster
        The trained booster (the saved file holds only its best iterations).
    """
    import os  # local import so the Ray worker does not rely on notebook globals

    metric = 'mse'
    early_stopping_rounds = 5

    param = {'metric': metric,
             'num_iterations': num_iterations,
             'early_stopping_rounds': early_stopping_rounds,
             'first_metric_only': True}

    # free_raw_data=False keeps the raw features on the Dataset; presumably
    # needed so LightGBM can score them when continuing from init_model.
    lgbTrain = lgb.Dataset(x_trainRTheta, label=y_trainRTheta[:, index], free_raw_data=False)
    lgbValid = lgbTrain.create_valid(x_validRTheta, label=y_validRTheta[:, index])

    bst = lgb.train(param,
                    lgbTrain,
                    valid_sets=[lgbValid, lgbTrain],
                    init_model=init_model)

    # Make sure the output directory exists. Otherwise a fresh run would fail
    # only here, after all the training work has already been done.
    os.makedirs(modelDirectory, exist_ok=True)
    bst.save_model(modelDirectory + 'lgbModel_%s.txt' % str(index), num_iteration=bst.best_iteration)
    return bst

def lgbPredictIndex(x, index):
    """Predict target column ``index`` for ``x`` using its saved LightGBM model.

    Loads ``modelDirectory + 'lgbModel_<index>.txt'`` and returns
    ``Booster.predict`` on ``x``. This is presumably one value per row, since
    the models are single-output regressors.
    """
    modelFile = modelDirectory + 'lgbModel_%s.txt' % str(index)
    booster = lgb.Booster(model_file=modelFile)
    return booster.predict(x, num_iteration=booster.best_iteration)

def lgbPredict(x):
    """Predict every target column for ``x`` with the per-column saved models.

    Column j of the result holds the predictions of model j, for j in
    ``range(nY)``. The shape is presumably (n_rows, nY) for 1-D per-model
    predictions.
    """
    columnPredictions = []
    for targetIndex in range(nY):
        columnPredictions.append(lgbPredictIndex(x, targetIndex))
    return np.array(columnPredictions).T

# Put the four arrays in Ray's object store once, so the train_index tasks
# launched below can all share them by reference.
xTrainRay, yTrainRay, xValidRay, yValidRay = [
    ray.put(array)
    for array in (x_trainRTheta, y_trainRTheta, x_validRTheta, y_validRTheta)
]

In [None]:
# TRAIN LIGHTGBM MODEL

continueFlag = True   # warm-start each column from its previously saved model when available
num_iterations = 50

result_ids = []
for index in range(nY):
    modelPath = modelDirectory + 'lgbModel_%s.txt' % str(index)
    # Only warm-start from a model file that actually exists. On a first run
    # there is nothing to continue from, and passing a missing path would make
    # every task fail.
    if continueFlag and os.path.isfile(modelPath):
        init_model = modelPath
    else:
        init_model = None

    result_ids.append(train_index.remote(xTrainRay,
                                         yTrainRay,
                                         xValidRay,
                                         yValidRay,
                                         index,
                                         init_model,
                                         num_iterations))

# Wait for every task. This re-raises any task exception here instead of
# losing it, and guarantees all model files are written before the plotting
# cell loads them.
results = ray.get(result_ids)

# Plot results of fit

In [None]:
# PLOT PREDICTION VERSUS TRUTH

trainPlotFlag = False

# Choose the split before predicting: both lgbPredict(x) and the title depend
# on `x` and `trainTitle`, so they must be assigned first.
if trainPlotFlag:
    x = x_trainRTheta
    y = y_trainRTheta
    trainTitle = 'train'
else:
    x = x_validRTheta
    y = y_validRTheta
    trainTitle = 'validation'

freqPredict = lgbPredict(x)
title = 'LightGBM: ' + trainTitle

# Invert both the predicted and the true (r, theta) STFTs back to time series.
tPred, yPred = realSTFTtoTimeSeries(freqPredict)
tTrue, yTrue = realSTFTtoTimeSeries(y)

fig, ax = plt.subplots()
ax.plot(tPred, yPred, label='predict')
ax.plot(tTrue, yTrue, label='true')
ax.set_xlabel('time (s)')   # istft times are in seconds since fs = subsampleFreq (Hz)
ax.set_ylabel('reconstructed signal')
ax.legend()
ax.set_title(title)
plt.show()