In [1]:
import matplotlib.pyplot as plt
%matplotlib widget 

import numpy as np
import mne
#mne.datasets.sample.data_path()

import torch
import eegCompressModels
import imp
imp.reload(eegCompressModels)

import neptune
from neptune_pytorch import NeptuneLogger

  import imp


In [2]:
def imageCompare(start, channel = 0, plotOption="both"):

    original = data[channel,start:start + numSampleInput]

    modelInput = np.reshape(data[:, start:start + numSampleInput], (inSize, -1), order='F').astype('float32').flatten()
    encoded = model.encoder(torch.tensor(modelInput))
    decoded = np.reshape(model.decoder(encoded).detach().numpy(), (nChannel, numSampleInput),order="C")[channel, start:start + numSampleInput]

    fig = plt.figure()
    if plotOption == "both":
        plt.plot(original, label='original')
        plt.plot(decoded, label='decoded')
        plt.legend()
    elif plotOption == "orig":
        plt.plot(original)
        plt.title('original')
    else: 
        plt.plot(decoded)
        plt.title('decoded')

    return fig

### EEG data

In [3]:
#raw = mne.io.read_raw_fif('./mne_data/MNE-sample-data/MEG/sample/sample_audvis_raw.fif')
raw = mne.io.read_raw_edf('/teamspace/uploads/ExampleLTMFiles/SVD001.edf')

Extracting EDF parameters from /teamspace/uploads/ExampleLTMFiles/SVD001.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  raw = mne.io.read_raw_edf('/teamspace/uploads/ExampleLTMFiles/SVD001.edf')


In [4]:
print(raw)
print(raw.info)

<RawEDF | SVD001.edf, 46 x 1276416 (4986.0 s), ~48 KiB, data not loaded>
<Info | 8 non-empty values
 bads: []
 ch_names: Fp1, F7, T7, P7, O1, F3, C3, P3, A1, Fz, Cz, Fp2, F8, T8, P8, ...
 chs: 46 EEG
 custom_ref_applied: False
 highpass: 0.0 Hz
 lowpass: 128.0 Hz
 meas_date: 2001-01-01 04:46:55 UTC
 nchan: 46
 projs: []
 sfreq: 256.0 Hz
 subject_info: <subject_info | his_id: SVD001>
>


In [5]:
chanList = range(0,19)
nChannel = len(chanList)

data = raw.get_data()[chanList] #eeg channels
print(data.shape)

(19, 1276416)


In [6]:
for i in range(nChannel):
    data[i,:] = (data[i,:] - np.mean(data[i,:])) / np.std(data[i,:])

### Random data

In [7]:
'''
data = np.random.random((1,100000))
data = data - np.mean(data)
data = data/np.std(data)
nChannel = data.shape[0]
'''

'\ndata = np.random.random((1,100000))\ndata = data - np.mean(data)\ndata = data/np.std(data)\nnChannel = data.shape[0]\n'

### Define Model, etc.

In [8]:
# Set in/out parameters
numSampleInput = 5
outSizeRatio = 1.0
inSize = nChannel * numSampleInput
outSize = int(inSize * outSizeRatio)
print(inSize, outSize)

# Construct the DataLoader
dataset = eegCompressModels.CustomDataset(data, numSampleInput)
batch_size = 32
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Make the model
encoderSizeList = [inSize, outSize]
decoderSizeList = [outSize, inSize]
encoderActivationList = [False]
decoderActivationList = [False]

model = eegCompressModels.AE(encoderSizeList, decoderSizeList, encoderActivationList, decoderActivationList)
print(model)
loss_function = torch.nn.MSELoss()

95 95
AE(
  (encoder): Sequential(
    (0): Linear(in_features=95, out_features=95, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=95, out_features=95, bias=True)
  )
)


In [9]:
run = neptune.init_run(
    project="jettinger35/eegCompress",
    api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiIzMjFlMzY2MS1iOWZiLTRmZWEtOGMwNy0zOTVkMTljOGVjYTMifQ==",
)

npt_logger = NeptuneLogger(
    run=run,
    model=model)



[neptune] [info   ] Neptune initialized. Open in the app: https://app.neptune.ai/jettinger35/eegCompress/e/EEG-92


In [10]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.0)
#optimizer = torch.optim.Adam(model.parameters())#, lr = 1e-1, weight_decay = 1e-8)

totalEpoch = 0

In [None]:
epochs = 40
startPlot = 0

outputs = []
losses = []

for epoch in range(epochs):
	for (image, _) in loader:
		# Output of Autoencoder
		reconstructed = model(image.to(torch.float32))
		
		# Calculating the loss function
		loss = loss_function(reconstructed, image)
		
		# The gradients are set to zero,
		# the gradient is computed and stored.
		# .step() performs parameter update
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		
		# Storing the losses in a list for plotting
		losses.append(loss)
		outputs.append((epochs, image, reconstructed))
		run[npt_logger.base_namespace]["train/log_loss"].append(np.log(loss.item()))

	fig = imageCompare(startPlot)
	plt.title("Total Epoch: " + str(totalEpoch))
	run["fig"].append(fig)
	plt.close()
    totalEpoch = totalEpoch + 1

torch.save(model.state_dict(), 'savedModel')

In [None]:
list(model.parameters())[0].grad

In [None]:
for name, param in model.named_parameters():
    print((name, param.grad))

In [None]:
startPlot = 0
plt.show(imageCompare(startPlot))

### Misc