In [7]:
%matplotlib notebook
import numpy as np
import matplotlib.pyplot as plt

from scipy.signal import spectrogram, stft, istft, check_NOLA

#import ray
#ray.init(include_dashboard=True, num_cpus = 8, dashboard_host='0.0.0.0')

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torchvision.transforms import ToTensor
from torch.utils.tensorboard import SummaryWriter

import os
api_token = os.environ.get('NEPTUNE_API_TOKEN')

plt.style.use('ggplot')

# UTILITY FUNCTIONS

In [8]:
subsampleFreq = 64   # FINAL FREQUENCY IN HERTZ AFTER SUBSAMPLING
secondsInWindow = 1.
nperseg = subsampleFreq * secondsInWindow
noverlap = nperseg - 1
window = ('tukey', .25)

In [9]:
# CONVERT STFT FROM R,THETA TO COMPLEX
# dim(z) = (# timesteps, # freq bins x 2 (2 reals = 1 complex))

def rThetaToComplex(z):
    rows, cols = z.shape
    shortTermFourier = np.zeros((rows, cols // 2), dtype=np.csingle)
    for i in range(rows):
        for k in range(cols // 2):
            r = z[i,k]
            theta = z[i, (k + cols // 2)]
            shortTermFourier[i,k] =  r * np.exp(complex(0, theta))
    return shortTermFourier.transpose() # dim = (# freq bins, # timepoints)

# CONVERT REAL STFT TO COMPLEX STFT, INVERT TO GET THE ISTFT (I.E. TIME SERIES), THEN PLOT

def realSTFTtoTimeSeries(realSTFT):
    shortTermFourierComplex = rThetaToComplex(realSTFT)
    times, inverseShortFourier = istft(shortTermFourierComplex, 
                                       fs=subsampleFreq, 
                                       window=window, 
                                       nperseg=nperseg, 
                                       noverlap=noverlap)
    return times, inverseShortFourier

# LOAD NUMPY ARRAYS

In [10]:
arraySavePath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/trainTestRTheta.npz'
modelPath = '/blue/gkalamangalam/jmark.ettinger/predictScalp/pytorchModels/model1.pth'

npzfile = np.load(arraySavePath)
x_trainRTheta = npzfile['x_trainRTheta']
x_validRTheta = npzfile['x_validRTheta'] 
y_trainRTheta = npzfile['y_trainRTheta'] 
y_validRTheta = npzfile['y_validRTheta']

_,nY = y_validRTheta.shape

In [11]:
batch_size = 64

trainXTensor = torch.Tensor(x_trainRTheta)
trainYTensor = torch.Tensor(y_trainRTheta)

trainDataset = TensorDataset(trainXTensor,trainYTensor)
trainDataLoader = DataLoader(trainDataset,batch_size=batch_size, shuffle=True)

validXTensor = torch.Tensor(x_validRTheta)
validYTensor = torch.Tensor(y_validRTheta)

validDataset = TensorDataset(validXTensor,validYTensor)
validDataLoader = DataLoader(validDataset,batch_size=batch_size, shuffle=True)


print("train: ")
for X, y in trainDataLoader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break
    
print("\ntest: ")
for X, y in validDataLoader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

train: 
Shape of X [N, C, H, W]: torch.Size([64, 5742])
Shape of y: torch.Size([64, 66]) torch.float32

test: 
Shape of X [N, C, H, W]: torch.Size([64, 5742])
Shape of y: torch.Size([64, 66]) torch.float32


In [14]:
loadFlag = False

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(5742, 512),
            nn.ReLU(),
            nn.Linear(512, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 66)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

class EarlyStopper:
    def __init__(self, patience, min_delta):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

model = NeuralNetwork()
if loadFlag == True:
    model.load_state_dict(torch.load(modelPath))
    
model = model.to(device)
print(model)

Using cuda device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=5742, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=100, bias=True)
    (3): ReLU()
    (4): Linear(in_features=100, out_features=100, bias=True)
    (5): ReLU()
    (6): Linear(in_features=100, out_features=100, bias=True)
    (7): ReLU()
    (8): Linear(in_features=100, out_features=100, bias=True)
    (9): ReLU()
    (10): Linear(in_features=100, out_features=66, bias=True)
  )
)


In [15]:
def train(dataloader, model, loss_fn, optimizer, epoch):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
        return loss
            
def test(dataloader, model, loss_fn, epoch):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
    test_loss /= num_batches
    print(f"Test Error: \n Avg loss: {test_loss:>8f} \n")
    return test_loss

In [None]:
epochs = 5000
patience = 50
min_delta = 0

run = neptune.init_run(
    project='jettinger35/test',
    api_token=api_token,  
)

#writer = SummaryWriter()
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loss = train(trainDataLoader, model, loss_fn, optimizer, t)
    test_loss = test(validDataLoader, model, loss_fn, t)
    #writer.add_scalars('loss', {'train': train_loss, 'test': test_loss}, t)
    #writer.flush()
    
    run["train/loss"].append(train_loss)
    run["test/loss"].append(test_loss)
    
    if early_stopper.early_stop(test_loss):   
        print("Early stopping invoked....")
        break
    
#writer.close()
run.stop()
print("Done!")

https://new-ui.neptune.ai/jettinger35/test/e/TES-2
Epoch 1
-------------------------------
loss: 23.332014  [   64/89440]
Test Error: 
 Avg loss: 18.194602 

Epoch 2
-------------------------------
loss: 18.210938  [   64/89440]
Test Error: 
 Avg loss: 18.114377 

Epoch 3
-------------------------------
loss: 24.251654  [   64/89440]
Test Error: 
 Avg loss: 18.049251 

Epoch 4
-------------------------------
loss: 17.322168  [   64/89440]
Test Error: 
 Avg loss: 17.989240 

Epoch 5
-------------------------------
loss: 22.910042  [   64/89440]
Test Error: 
 Avg loss: 17.845835 

Epoch 6
-------------------------------
loss: 19.224859  [   64/89440]
Test Error: 
 Avg loss: 17.741329 

Epoch 7
-------------------------------
loss: 18.152977  [   64/89440]
Test Error: 
 Avg loss: 17.624545 

Epoch 8
-------------------------------
loss: 20.767334  [   64/89440]
Test Error: 
 Avg loss: 17.474685 

Epoch 9
-------------------------------
loss: 21.231815  [   64/89440]
Test Error: 
 Avg loss

Test Error: 
 Avg loss: 7.372544 

Epoch 78
-------------------------------
loss: 6.147372  [   64/89440]
Test Error: 
 Avg loss: 7.555678 

Epoch 79
-------------------------------
loss: 6.935519  [   64/89440]
Test Error: 
 Avg loss: 7.371442 

Epoch 80
-------------------------------
loss: 6.413792  [   64/89440]
Test Error: 
 Avg loss: 7.381949 

Epoch 81
-------------------------------
loss: 8.227101  [   64/89440]
Test Error: 
 Avg loss: 7.381698 

Epoch 82
-------------------------------
loss: 5.192735  [   64/89440]
Test Error: 
 Avg loss: 7.369535 

Epoch 83
-------------------------------
loss: 6.523130  [   64/89440]
Test Error: 
 Avg loss: 7.373960 

Epoch 84
-------------------------------
loss: 6.272357  [   64/89440]
Test Error: 
 Avg loss: 7.346574 

Epoch 85
-------------------------------
loss: 5.255451  [   64/89440]
Test Error: 
 Avg loss: 7.343073 

Epoch 86
-------------------------------
loss: 7.058350  [   64/89440]
Test Error: 
 Avg loss: 7.393526 

Epoch 87
--

Test Error: 
 Avg loss: 7.088138 

Epoch 155
-------------------------------
loss: 5.231828  [   64/89440]
Test Error: 
 Avg loss: 7.088898 

Epoch 156
-------------------------------
loss: 8.538354  [   64/89440]
Test Error: 
 Avg loss: 7.089791 

Epoch 157
-------------------------------
loss: 11.791036  [   64/89440]
Test Error: 
 Avg loss: 7.062018 

Epoch 158
-------------------------------
loss: 9.617811  [   64/89440]
Test Error: 
 Avg loss: 7.077949 

Epoch 159
-------------------------------
loss: 9.427655  [   64/89440]
Test Error: 
 Avg loss: 7.171526 

Epoch 160
-------------------------------
loss: 5.630604  [   64/89440]
Test Error: 
 Avg loss: 7.148862 

Epoch 161
-------------------------------
loss: 5.273030  [   64/89440]
Test Error: 
 Avg loss: 7.082390 

Epoch 162
-------------------------------
loss: 7.114505  [   64/89440]
Test Error: 
 Avg loss: 7.049162 

Epoch 163
-------------------------------
loss: 13.890153  [   64/89440]
Test Error: 
 Avg loss: 7.340038 



Test Error: 
 Avg loss: 6.732178 

Epoch 232
-------------------------------
loss: 8.871037  [   64/89440]
Test Error: 
 Avg loss: 6.729035 

Epoch 233
-------------------------------
loss: 4.928020  [   64/89440]
Test Error: 
 Avg loss: 6.752430 

Epoch 234
-------------------------------
loss: 7.753354  [   64/89440]
Test Error: 
 Avg loss: 6.834114 

Epoch 235
-------------------------------
loss: 8.210329  [   64/89440]
Test Error: 
 Avg loss: 6.698081 

Epoch 236
-------------------------------
loss: 8.446198  [   64/89440]
Test Error: 
 Avg loss: 6.817297 

Epoch 237
-------------------------------
loss: 5.719049  [   64/89440]
Test Error: 
 Avg loss: 6.708968 

Epoch 238
-------------------------------
loss: 6.221739  [   64/89440]
Test Error: 
 Avg loss: 6.683065 

Epoch 239
-------------------------------
loss: 4.683679  [   64/89440]
Test Error: 
 Avg loss: 6.721447 

Epoch 240
-------------------------------
loss: 6.367719  [   64/89440]
Test Error: 
 Avg loss: 6.681021 

Ep

Test Error: 
 Avg loss: 6.492410 

Epoch 309
-------------------------------
loss: 7.052158  [   64/89440]
Test Error: 
 Avg loss: 6.614603 

Epoch 310
-------------------------------
loss: 7.066903  [   64/89440]
Test Error: 
 Avg loss: 6.473946 

Epoch 311
-------------------------------
loss: 5.481149  [   64/89440]
Test Error: 
 Avg loss: 6.466201 

Epoch 312
-------------------------------
loss: 5.037040  [   64/89440]
Test Error: 
 Avg loss: 6.505214 

Epoch 313
-------------------------------
loss: 9.470856  [   64/89440]
Test Error: 
 Avg loss: 6.453732 

Epoch 314
-------------------------------
loss: 6.175179  [   64/89440]
Test Error: 
 Avg loss: 6.573549 

Epoch 315
-------------------------------
loss: 9.396106  [   64/89440]
Test Error: 
 Avg loss: 6.453189 

Epoch 316
-------------------------------
loss: 5.825189  [   64/89440]
Test Error: 
 Avg loss: 6.448159 

Epoch 317
-------------------------------
loss: 5.210443  [   64/89440]
Test Error: 
 Avg loss: 6.451455 

Ep

Test Error: 
 Avg loss: 6.434169 

Epoch 386
-------------------------------
loss: 10.415342  [   64/89440]
Test Error: 
 Avg loss: 6.394022 

Epoch 387
-------------------------------
loss: 6.368142  [   64/89440]
Test Error: 
 Avg loss: 6.278430 

Epoch 388
-------------------------------
loss: 5.318780  [   64/89440]
Test Error: 
 Avg loss: 6.315503 

Epoch 389
-------------------------------
loss: 6.178883  [   64/89440]
Test Error: 
 Avg loss: 6.279296 

Epoch 390
-------------------------------
loss: 8.435814  [   64/89440]
Test Error: 
 Avg loss: 6.482000 

Epoch 391
-------------------------------
loss: 5.679610  [   64/89440]
Test Error: 
 Avg loss: 6.301472 

Epoch 392
-------------------------------
loss: 5.856975  [   64/89440]
Test Error: 
 Avg loss: 6.286811 

Epoch 393
-------------------------------
loss: 4.410472  [   64/89440]
Test Error: 
 Avg loss: 6.277825 

Epoch 394
-------------------------------
loss: 6.705330  [   64/89440]
Test Error: 
 Avg loss: 6.289248 

E

Test Error: 
 Avg loss: 6.178534 

Epoch 463
-------------------------------
loss: 6.453291  [   64/89440]
Test Error: 
 Avg loss: 6.259959 

Epoch 464
-------------------------------
loss: 5.933983  [   64/89440]
Test Error: 
 Avg loss: 6.267306 

Epoch 465
-------------------------------
loss: 4.901986  [   64/89440]
Test Error: 
 Avg loss: 6.197164 

Epoch 466
-------------------------------
loss: 5.743174  [   64/89440]
Test Error: 
 Avg loss: 6.235100 

Epoch 467
-------------------------------
loss: 4.630851  [   64/89440]
Test Error: 
 Avg loss: 6.136221 

Epoch 468
-------------------------------
loss: 5.813499  [   64/89440]
Test Error: 
 Avg loss: 6.154040 

Epoch 469
-------------------------------
loss: 5.636886  [   64/89440]
Test Error: 
 Avg loss: 6.151587 

Epoch 470
-------------------------------
loss: 7.924644  [   64/89440]
Test Error: 
 Avg loss: 6.261087 

Epoch 471
-------------------------------
loss: 5.722291  [   64/89440]
Test Error: 
 Avg loss: 6.129488 

Ep

Test Error: 
 Avg loss: 6.337026 

Epoch 540
-------------------------------
loss: 4.788347  [   64/89440]
Test Error: 
 Avg loss: 6.315117 

Epoch 541
-------------------------------
loss: 4.958446  [   64/89440]
Test Error: 
 Avg loss: 6.143491 

Epoch 542
-------------------------------
loss: 6.479766  [   64/89440]
Test Error: 
 Avg loss: 6.187895 

Epoch 543
-------------------------------
loss: 5.037448  [   64/89440]
Test Error: 
 Avg loss: 6.174322 

Epoch 544
-------------------------------
loss: 6.478382  [   64/89440]
Test Error: 
 Avg loss: 6.116248 

Epoch 545
-------------------------------
loss: 5.792099  [   64/89440]
Test Error: 
 Avg loss: 6.145540 

Epoch 546
-------------------------------
loss: 5.045148  [   64/89440]
Test Error: 
 Avg loss: 6.142982 

Epoch 547
-------------------------------
loss: 10.117208  [   64/89440]
Test Error: 
 Avg loss: 6.530738 

Epoch 548
-------------------------------
loss: 5.079580  [   64/89440]
Test Error: 
 Avg loss: 6.235761 

E

Test Error: 
 Avg loss: 6.096280 

Epoch 617
-------------------------------
loss: 7.088270  [   64/89440]
Test Error: 
 Avg loss: 5.970308 

Epoch 618
-------------------------------
loss: 5.314685  [   64/89440]
Test Error: 
 Avg loss: 6.082989 

Epoch 619
-------------------------------
loss: 5.409071  [   64/89440]
Test Error: 
 Avg loss: 5.963160 

Epoch 620
-------------------------------
loss: 4.429886  [   64/89440]
Test Error: 
 Avg loss: 5.993743 

Epoch 621
-------------------------------
loss: 4.522482  [   64/89440]
Test Error: 
 Avg loss: 6.018512 

Epoch 622
-------------------------------
loss: 7.854012  [   64/89440]
Test Error: 
 Avg loss: 6.051407 

Epoch 623
-------------------------------
loss: 6.865088  [   64/89440]
Test Error: 
 Avg loss: 6.223365 

Epoch 624
-------------------------------
loss: 5.497069  [   64/89440]
Test Error: 
 Avg loss: 6.003786 

Epoch 625
-------------------------------
loss: 5.868363  [   64/89440]
Test Error: 
 Avg loss: 6.172258 

Ep

In [18]:
torch.save(model.state_dict(), modelPath)

# Plot results of fit

In [68]:
# PLOT PREDICTION VERSUS TRUTH

trainPlotFlag = False
    
if trainPlotFlag:
    x = trainXTensor
    y = y_trainRTheta
    trainTitle = 'train'
else:
    x = validXTensor
    y = y_validRTheta
    trainTitle = 'valididation'
    

x = validXTensor.to(device)
freqPredict = model(x).cpu().detach().numpy()

_, yPred = realSTFTtoTimeSeries(freqPredict)
_, yTrue = realSTFTtoTimeSeries(y)

lossTemp = loss_fn(torch.tensor(yPred), torch.tensor(yTrue)).item()
title = 'PyTorch: ' + trainTitle + ' (mse: %s)' % str(lossTemp)
plt.figure()
plt.plot(yPred, label='predict')
plt.plot(yTrue, label='true')
plt.legend()
plt.title(title)
plt.show()

<IPython.core.display.Javascript object>

# SCRATCH BELOW

In [6]:
import neptune

# Create a Neptune run object
run = neptune.init_run(
    project='jettinger35/test',
    api_token=api_token,  
)

# Track metadata and hyperparameters by assigning them to the run
run["JIRA"] = "NPT-952"
run["algorithm"] = "ConvNet"

PARAMS = {
    "batch_size": 64,
    "dropout": 0.2,
    "learning_rate": 0.001,
    "optimizer": "Adam",
}
run["parameters"] = PARAMS

# Track the training process by logging your training metrics
for epoch in range(10):
    run["train/accuracy"].append(epoch * 0.6)  
    run["train/loss"].append(epoch * 0.4)

# Record the final results
run["f1_score"] = 0.66

# Stop the connection and synchronize the data with the Neptune servers
run.stop()


https://new-ui.neptune.ai/jettinger35/test/e/TES-1
Shutting down background jobs, please wait a moment...
Done!
Waiting for the remaining 27 operations to synchronize with Neptune. Do not kill this process.
All 27 operations synced, thanks for waiting!
Explore the metadata in the Neptune app:
https://new-ui.neptune.ai/jettinger35/test/e/TES-1/metadata
