In [1]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt

from scipy.signal import spectrogram, stft, istft, check_NOLA

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torchvision.transforms import ToTensor
import os
import neptune
from neptune.utils import stringify_unsupported

import scalpDeepModels as sdm

import importlib

plt.style.use('ggplot')

# PARAMETERS - GENERAL

In [2]:
stftSavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/freqRTheta.npz'
timeDomainSavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/timeDomain.npz'
timeFreqSavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/timeFreqRTheta.npz'

modelPath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/pytorchModels/model.pth'

neptuneProject = 'jettinger35/predictScalp'
api_token = os.environ.get('NEPTUNE_API_TOKEN')

subsampleFreq = 128   # FINAL FREQUENCY IN HERTZ AFTER SUBSAMPLING
secondsInWindow = 1
nperseg = subsampleFreq * secondsInWindow
noverlap = nperseg - 1
window = ('tukey', .25)

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


# PARAMETERS - TRAINING

In [3]:
epochs = 5000
batch_size = 1024
learningRate = 1e-3
#loss_fn = nn.MSELoss()
loss_fn = nn.L1Loss()
optChoice = 'adam'

patience = 5000

# UTILITY FUNCTIONS

In [4]:
# CONVERT STFT FROM R,THETA TO COMPLEX
# dim(z) = (# timesteps, # freq bins x 2 (2 reals = 1 complex))

def rThetaToComplex(z):
    rows, cols = z.shape
    shortTermFourier = np.zeros((rows, cols // 2), dtype=np.csingle)
    for i in range(rows):
        for k in range(cols // 2):
            r = z[i,k]
            theta = z[i, (k + cols // 2)]
            shortTermFourier[i,k] =  r * np.exp(complex(0, theta))
    return shortTermFourier.transpose() # dim = (# freq bins, # timepoints)

# CONVERT REAL STFT TO COMPLEX STFT, INVERT TO GET THE ISTFT (I.E. TIME SERIES), THEN PLOT

def realSTFTtoTimeSeries(realSTFT):
    shortTermFourierComplex = rThetaToComplex(realSTFT)
    times, inverseShortFourier = istft(shortTermFourierComplex, 
                                       fs=subsampleFreq, 
                                       window=window, 
                                       nperseg=nperseg, 
                                       noverlap=noverlap)
    return times, inverseShortFourier

# LOAD NUMPY DATA ARRAYS

In [5]:
dataSwitch = 'time'

if dataSwitch == 'freq':
    # STFT DATA

    npzfile = np.load(stftSavePath)
    x_trainRTheta = npzfile['x_trainRTheta']
    x_validRTheta = npzfile['x_validRTheta'] 
    y_trainRTheta = npzfile['y_trainRTheta'] 
    y_validRTheta = npzfile['y_validRTheta']

    trainXTensor = torch.Tensor(x_trainRTheta)
    trainYTensor = torch.Tensor(y_trainRTheta)
    validXTensor = torch.Tensor(x_validRTheta)
    validYTensor = torch.Tensor(y_validRTheta)

elif dataSwitch == 'time':
    # TIME DOMAIN DATA

    npzfile = np.load(timeDomainSavePath)
    xTrainTimeDomain = npzfile['xTrainTimeDomain']
    xValidTimeDomain = npzfile['xValidTimeDomain'] 
    yTrainTimeDomain = npzfile['yTrainTimeDomain'] 
    yValidTimeDomain = npzfile['yValidTimeDomain']

    trainXTensor = torch.Tensor(xTrainTimeDomain)
    trainYTensor = torch.Tensor(yTrainTimeDomain)
    validXTensor = torch.Tensor(xValidTimeDomain)
    validYTensor = torch.Tensor(yValidTimeDomain)
    
elif dataSwitch == 'timeFreq':
    
    npzfile = np.load(timeFreqSavePath)
    xTrain = npzfile['x_trainTimeFreq']
    xValid = npzfile['x_validTimeFreq'] 
    yTrain = npzfile['y_trainTimeFreq'] 
    yValid = npzfile['y_validTimeFreq']

    trainXTensor = torch.Tensor(xTrain)
    trainYTensor = torch.Tensor(yTrain)
    validXTensor = torch.Tensor(xValid)
    validYTensor = torch.Tensor(yValid)

In [6]:
# CREATE PYTORCH DATALOADERS

trainDataset = TensorDataset(trainXTensor,trainYTensor)
trainDataLoader = DataLoader(trainDataset,batch_size=batch_size, shuffle=True)

validDataset = TensorDataset(validXTensor,validYTensor)
validDataLoader = DataLoader(validDataset,batch_size=batch_size, shuffle=True)


print("train: ")
for X, y in trainDataLoader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
    
print("\ntest: ")
for X, y in validDataLoader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

train: 
Shape of X [N, C, H, W]: torch.Size([1024, 5655])
Shape of y: torch.Size([1024, 1]) torch.float32

test: 
Shape of X [N, C, H, W]: torch.Size([1024, 5655])
Shape of y: torch.Size([1024, 1]) torch.float32


# DEFINE OR LOAD THE MODEL

In [17]:
run = neptune.init_run(
    project=neptuneProject,
    api_token=api_token,  
    capture_hardware_metrics=True,
    capture_stderr=True,
    capture_stdout=True,
    with_id="PRED-43"
)

try:
    destinationPathModel = modelPath
    run["model_best"].download(destinationPathModel)
    print("model download success...")
    run.stop()
except Exception as error:
    print("model download failure...")
    print(error)
    run.stop()

https://new-ui.neptune.ai/jettinger35/predictScalp/e/PRED-43
model download success...
Shutting down background jobs, please wait a moment...
Done!
Waiting for the remaining 6 operations to synchronize with Neptune. Do not kill this process.
All 6 operations synced, thanks for waiting!
Explore the metadata in the Neptune app:
https://new-ui.neptune.ai/jettinger35/predictScalp/e/PRED-43/metadata


In [23]:
model = NeuralNetwork(
  (model): Sequential(
    (bn0): BatchNorm1d(5655, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (l0): Linear(in_features=5655, out_features=512, bias=True)
    (r0): ReLU()
    (d0): Dropout(p=0.5, inplace=False)
    (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (l1): Linear(in_features=512, out_features=512, bias=True)
    (r1): ReLU()
    (d1): Dropout(p=0.5, inplace=False)
    (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (l2): Linear(in_features=512, out_features=512, bias=True)
    (r2): ReLU()
    (d2): Dropout(p=0.5, inplace=False)
    (bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (l3): Linear(in_features=512, out_features=512, bias=True)
    (r3): ReLU()
    (d3): Dropout(p=0.5, inplace=False)
    (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (l4): Linear(in_features=512, out_features=512, bias=True)
    (r4): ReLU()
    (d4): Dropout(p=0.5, inplace=False)
    (bn5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (l5): Linear(in_features=512, out_features=1, bias=True)
  )
)


Sequential(
  (in): Linear(in_features=5655, out_features=512, bias=True)
  (bn0): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (l0): Linear(in_features=512, out_features=512, bias=True)
  (r0): ReLU()
  (d0): Dropout(p=0.5, inplace=False)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (l1): Linear(in_features=512, out_features=512, bias=True)
  (r1): ReLU()
  (d1): Dropout(p=0.5, inplace=False)
  (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (l2): Linear(in_features=512, out_features=512, bias=True)
  (r2): ReLU()
  (d2): Dropout(p=0.5, inplace=False)
  (bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (l3): Linear(in_features=512, out_features=512, bias=True)
  (out): Linear(in_features=512, out_features=1, bias=True)
)

SyntaxError: invalid syntax (2801460522.py, line 2)

In [24]:
# DEFINE MODEL

modelLoadFlag = False

if modelLoadFlag == True:
    model = torch.load(modelPath)
    bestTestLoss = sdm.test(validDataLoader, model, loss_fn)
else:
    importlib.reload(sdm) # reload in case we've made any architecture changes
    
    # DEFINE ARCHITECTURE HERE
    inputSize = trainXTensor.shape[1]
    hiddenLayerSizes = [512,512,512,512,512]
    layerDict = sdm.listToOrderedDict_1(inputSize, hiddenLayerSizes)
    #layerDict = sdm.residualAddDict(inputSize, 512, 5)
    #layerDict = sdm.residualConcatDict(inputSize, hiddenLayerSizes)
    model = nn.Sequential(layerDict)
    bestTestLoss = float('inf')
    
print("Number of parameters: ", sdm.count_parameters(model))
model = model.to(device)
print(model)

Number of parameters:  3951105
Sequential(
  (in): Linear(in_features=5655, out_features=512, bias=True)
  (bn0): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (l0): Linear(in_features=512, out_features=512, bias=True)
  (r0): ReLU()
  (d0): Dropout(p=0.5, inplace=False)
  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (l1): Linear(in_features=512, out_features=512, bias=True)
  (r1): ReLU()
  (d1): Dropout(p=0.5, inplace=False)
  (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (l2): Linear(in_features=512, out_features=512, bias=True)
  (r2): ReLU()
  (d2): Dropout(p=0.5, inplace=False)
  (bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (l3): Linear(in_features=512, out_features=512, bias=True)
  (out): Linear(in_features=512, out_features=1, bias=True)
)


# TRAIN (LOG DATA TO NEPTUNE)

In [None]:
if optChoice == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(), lr=learningRate)
elif optChoice == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=learningRate)
else:
    optimizer = None
    print('no optimizer chosen...')

run = neptune.init_run(
    project=neptuneProject,
    api_token=api_token,  
    capture_hardware_metrics=True,
    capture_stderr=True,
    capture_stdout=True,
)

PARAMS = {
    "batch_size": batch_size,
    "learning_rate": learningRate,
    "optimizer": optChoice,
    "patience": patience,
    "subsampleFreq": subsampleFreq,
    "secondsInWindow": secondsInWindow,
    "nperseg": nperseg,
    "noverlap": noverlap,
    "window": stringify_unsupported(window),
    "loss_fn": stringify_unsupported(loss_fn),
    "architectureString": str(model),
    "numParameters": sdm.count_parameters(model)
}
run["parameters"] = PARAMS

noImprovementCount = 0

#epochs = 2

try:
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loss = sdm.train(trainDataLoader, model, loss_fn, optimizer, device)
        test_loss = sdm.test(validDataLoader, model, loss_fn, device)

        if test_loss < bestTestLoss:
            noImprovementCount = 0
            bestTestLoss = test_loss
            torch.save(model, modelPath)
            run["model_best"].upload(modelPath)
            run["best_test_loss"] =  bestTestLoss
            run["best_test_epoch"] = t
            print("\nSaved a new best model!\n")
        else:
            noImprovementCount = noImprovementCount + 1

        run["train/loss"].append(train_loss)
        run["test/loss"].append(test_loss)

        if noImprovementCount >= patience:   
            print("Early stopping invoked....")
            break

    run.stop()
    print("Done!")
except:
    run.stop()
    print("Training aborted...")

https://new-ui.neptune.ai/jettinger35/predictScalp/e/PRED-54
Epoch 1
-------------------------------
loss: 0.791030  [ 1024/176128]
Test Error: 
 Avg loss: 10.295714 


Saved a new best model!

Epoch 2
-------------------------------
loss: 10.402260  [ 1024/176128]
Test Error: 
 Avg loss: 2.547617 


Saved a new best model!

Epoch 3
-------------------------------
loss: 2.569907  [ 1024/176128]
Test Error: 
 Avg loss: 4.647707 

Epoch 4
-------------------------------
loss: 3.956735  [ 1024/176128]
Test Error: 
 Avg loss: 3.747100 

Epoch 5
-------------------------------
loss: 3.369908  [ 1024/176128]
Test Error: 
 Avg loss: 1.226932 


Saved a new best model!

Epoch 6
-------------------------------
loss: 1.160034  [ 1024/176128]
Test Error: 
 Avg loss: 2.090100 

Epoch 7
-------------------------------
loss: 1.977009  [ 1024/176128]
Test Error: 
 Avg loss: 2.399533 

Epoch 8
-------------------------------
loss: 1.897463  [ 1024/176128]
Test Error: 
 Avg loss: 1.077463 


Saved a ne

Test Error: 
 Avg loss: 0.440782 

Epoch 63
-------------------------------
loss: 0.419894  [ 1024/176128]
Test Error: 
 Avg loss: 0.432032 

Epoch 64
-------------------------------
loss: 0.416068  [ 1024/176128]
Test Error: 
 Avg loss: 0.436964 

Epoch 65
-------------------------------
loss: 0.396233  [ 1024/176128]
Test Error: 
 Avg loss: 0.435962 

Epoch 66
-------------------------------
loss: 0.411000  [ 1024/176128]
Test Error: 
 Avg loss: 0.440059 

Epoch 67
-------------------------------
loss: 0.401515  [ 1024/176128]
Test Error: 
 Avg loss: 0.417616 


Saved a new best model!

Epoch 68
-------------------------------
loss: 0.399112  [ 1024/176128]
Test Error: 
 Avg loss: 0.417394 


Saved a new best model!

Epoch 69
-------------------------------
File /blue/gkalamangalam/jmark.ettinger/predictScalp/pytorchModels/model.pth changed during upload, restarting upload.
loss: 0.410632  [ 1024/176128]
Test Error: 
 Avg loss: 0.420506 

Epoch 70
-------------------------------
loss

Test Error: 
 Avg loss: 0.412946 

Epoch 138
-------------------------------
loss: 0.380495  [ 1024/176128]
Test Error: 
 Avg loss: 0.434890 

Epoch 139
-------------------------------
loss: 0.414510  [ 1024/176128]
Test Error: 
 Avg loss: 0.414605 

Epoch 140
-------------------------------
loss: 0.343583  [ 1024/176128]
Test Error: 
 Avg loss: 0.415948 

Epoch 141
-------------------------------
loss: 0.359177  [ 1024/176128]
Test Error: 
 Avg loss: 0.416079 

Epoch 142
-------------------------------
loss: 0.347022  [ 1024/176128]
Test Error: 
 Avg loss: 0.437940 

Epoch 143
-------------------------------
loss: 0.369510  [ 1024/176128]
Test Error: 
 Avg loss: 0.418840 

Epoch 144
-------------------------------
loss: 0.357942  [ 1024/176128]
Test Error: 
 Avg loss: 0.421062 

Epoch 145
-------------------------------
loss: 0.379026  [ 1024/176128]
Test Error: 
 Avg loss: 0.428634 

Epoch 146
-------------------------------
loss: 0.361658  [ 1024/176128]
Test Error: 
 Avg loss: 0.43

Test Error: 
 Avg loss: 0.435036 

Epoch 214
-------------------------------
loss: 0.352047  [ 1024/176128]
Test Error: 
 Avg loss: 0.422822 

Epoch 215
-------------------------------
loss: 0.353476  [ 1024/176128]
Test Error: 
 Avg loss: 0.438494 

Epoch 216
-------------------------------
loss: 0.350365  [ 1024/176128]
Test Error: 
 Avg loss: 0.417827 

Epoch 217
-------------------------------
loss: 0.325515  [ 1024/176128]
Test Error: 
 Avg loss: 0.444287 

Epoch 218
-------------------------------
loss: 0.394912  [ 1024/176128]
Test Error: 
 Avg loss: 0.445041 

Epoch 219
-------------------------------
loss: 0.373306  [ 1024/176128]
Test Error: 
 Avg loss: 0.433391 

Epoch 220
-------------------------------
loss: 0.352968  [ 1024/176128]
Test Error: 
 Avg loss: 0.446151 

Epoch 221
-------------------------------
loss: 0.390013  [ 1024/176128]
Test Error: 
 Avg loss: 0.413044 

Epoch 222
-------------------------------
loss: 0.342405  [ 1024/176128]
Test Error: 
 Avg loss: 0.46

Test Error: 
 Avg loss: 0.414295 

Epoch 290
-------------------------------
loss: 0.304983  [ 1024/176128]
Test Error: 
 Avg loss: 0.419369 

Epoch 291
-------------------------------
loss: 0.315136  [ 1024/176128]
Test Error: 
 Avg loss: 0.408983 

Epoch 292
-------------------------------
loss: 0.303111  [ 1024/176128]
Test Error: 
 Avg loss: 0.408930 

Epoch 293
-------------------------------
loss: 0.299806  [ 1024/176128]
Test Error: 
 Avg loss: 0.414634 

Epoch 294
-------------------------------
loss: 0.312400  [ 1024/176128]
Test Error: 
 Avg loss: 0.417148 

Epoch 295
-------------------------------
loss: 0.313937  [ 1024/176128]
Test Error: 
 Avg loss: 0.420056 

Epoch 296
-------------------------------
loss: 0.305063  [ 1024/176128]
Test Error: 
 Avg loss: 0.422166 

Epoch 297
-------------------------------
loss: 0.318786  [ 1024/176128]
Test Error: 
 Avg loss: 0.430743 

Epoch 298
-------------------------------
loss: 0.315312  [ 1024/176128]
Test Error: 
 Avg loss: 0.42

Test Error: 
 Avg loss: 0.449450 

Epoch 366
-------------------------------
loss: 0.339558  [ 1024/176128]
Test Error: 
 Avg loss: 0.420901 

Epoch 367
-------------------------------
loss: 0.301574  [ 1024/176128]
Test Error: 
 Avg loss: 0.445018 

Epoch 368
-------------------------------
loss: 0.333489  [ 1024/176128]
Test Error: 
 Avg loss: 0.451808 

Epoch 369
-------------------------------
loss: 0.355386  [ 1024/176128]
Test Error: 
 Avg loss: 0.431539 

Epoch 370
-------------------------------
loss: 0.328301  [ 1024/176128]
Test Error: 
 Avg loss: 0.425313 

Epoch 371
-------------------------------
loss: 0.292490  [ 1024/176128]
Test Error: 
 Avg loss: 0.434550 

Epoch 372
-------------------------------
loss: 0.309259  [ 1024/176128]
Test Error: 
 Avg loss: 0.418084 

Epoch 373
-------------------------------
loss: 0.292752  [ 1024/176128]
Test Error: 
 Avg loss: 0.420031 

Epoch 374
-------------------------------
loss: 0.304241  [ 1024/176128]
Test Error: 
 Avg loss: 0.41

Test Error: 
 Avg loss: 0.448590 

Epoch 442
-------------------------------
loss: 0.336440  [ 1024/176128]
Test Error: 
 Avg loss: 0.417251 

Epoch 443
-------------------------------
loss: 0.281878  [ 1024/176128]
Test Error: 
 Avg loss: 0.451249 

Epoch 444
-------------------------------
loss: 0.311995  [ 1024/176128]
Test Error: 
 Avg loss: 0.437844 

Epoch 445
-------------------------------
loss: 0.307039  [ 1024/176128]
Test Error: 
 Avg loss: 0.423368 

Epoch 446
-------------------------------
loss: 0.287223  [ 1024/176128]
Test Error: 
 Avg loss: 0.438763 

Epoch 447
-------------------------------
loss: 0.328052  [ 1024/176128]
Test Error: 
 Avg loss: 0.425937 

Epoch 448
-------------------------------
loss: 0.327206  [ 1024/176128]
Test Error: 
 Avg loss: 0.422603 

Epoch 449
-------------------------------
loss: 0.287013  [ 1024/176128]
Test Error: 
 Avg loss: 0.433052 

Epoch 450
-------------------------------
loss: 0.277643  [ 1024/176128]
Test Error: 
 Avg loss: 0.42

# PLOT RESULTS OF FIT

In [None]:
# PLOT PREDICTION VERSUS TRUTH

trainPlotFlag = True
    
if trainPlotFlag:
    x = trainXTensor
    trainTitle = 'train'
else:
    x = validXTensor
    trainTitle = 'valididation'

model.to('cpu')
predict = model(x).cpu().detach().numpy()
model.to(device)

if predict.shape[1] == 1:
    yPred = predict[:,0]
    if trainPlotFlag:
        yTrue = yTrainTimeDomain[:,0]
    else:
        yTrue = yValidTimeDomain[:,0]
else:
    _, yPred = realSTFTtoTimeSeries(predict)
    if trainPlotFlag:
        y = y_trainRTheta
        _, yTrue = realSTFTtoTimeSeries(y)
    else:
        y = y_validRTheta
        _, yTrue = realSTFTtoTimeSeries(y)
        

lossTemp = loss_fn(torch.tensor(yPred), torch.tensor(yTrue)).item()
title = 'Data: ' + trainTitle + ' (loss: %s)' % str(lossTemp)
plt.figure()
plt.plot(yPred, label='predict')
plt.plot(yTrue, label='true')
plt.legend()
plt.title(title)
plt.show()

# SCRATCH

In [12]:
from scipy.signal import spectrogram, stft, istft, check_NOLA

fs = 1
nperseg = 32
noverlap = 31
#windowType = ('tukey', .25)
windowType = np.ones(nperseg)


a = np.random.rand(100)
f, t, S = stft(a, fs=fs, window=windowType, nperseg=nperseg, noverlap=noverlap)

b = torch.stft(torch.tensor(a), 
               n_fft = nperseg, 
               hop_length = 1, 
               return_complex=True, 
               normalized=False, 
               onesided=True, 
               pad_mode='constant').numpy()

np.abs(np.divide(b,S))

array([[32., 32., 32., ..., 32., 32., 32.],
       [32., 32., 32., ..., 32., 32., 32.],
       [32., 32., 32., ..., 32., 32., 32.],
       ...,
       [32., 32., 32., ..., 32., 32., 32.],
       [32., 32., 32., ..., 32., 32., 32.],
       [32., 32., 32., ..., 32., 32., 32.]])

In [7]:
# HOW TO GRAB DATA FROM NEPTUNE

project = neptune.init_project(project="jettinger35/predictScalp")
df = project.fetch_runs_table().to_pandas()
df[['sys/id','best_test_loss']]

https://new-ui.neptune.ai/jettinger35/predictScalp/


Unnamed: 0,sys/id,best_test_loss
0,PRED-54,0.391573
1,PRED-53,0.423721
2,PRED-46,0.40543
3,PRED-43,0.390877
4,PRED-38,0.398955
5,PRED-35,0.403722
6,PRED-34,0.417772
7,PRED-32,0.408214
8,PRED-31,0.40969


In [None]:
from scipy.signal import get_window
a = get_window(('tukey', .25), nperseg)
a

In [None]:
'''
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
        return loss
            
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
    test_loss /= num_batches
    print(f"Test Error: \n Avg loss: {test_loss:>8f} \n")
    return test_loss

class NeuralNetwork(nn.Module):
    
    def __init__(self, layerOrderedDict):
        super().__init__()
        self.model = nn.Sequential(layerOrderedDict)
        
    def forward(self, x):
        return self.model(x)
    
    
# GIVEN A LIST OF LAYER SIZES MAKE AN ORDERED DICTIONARY FOR INITIALIZING A PYTORCH NET

def listToOrderedDict(sizeList):
    n = len(sizeList)
    tupleList = []
    for i in range(n - 1):
        tupleList.append(('bn%s' % str(i), nn.BatchNorm1d(sizeList[i])))
        tupleList.append(('l%s' % str(i), nn.Linear(sizeList[i], sizeList[i+1])))
        tupleList.append(('r%s' % str(i), nn.ReLU()))
    return OrderedDict(tupleList[:-1])

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
'''

'''
    layerSizeList = [trainXTensor.shape[1]] + hiddenLayerSizes + [trainYTensor.shape[1]]
    layerOrderedDict = sdm.listToOrderedDict(layerSizeList)
    model = sdm.NeuralNetwork(layerOrderedDict)
    '''