In [None]:
# from glob import glob
import pandas as pd
import os
import numpy as np
import json
from scipy import signal
import re
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.preprocessing import RobustScaler
from sklearn.preprocessing import QuantileTransformer
import numpy.random as npr
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math

In [4]:
# Dataset class
class EEGDataset(Dataset):

    # Constructor
    def __init__(self, data_dict, split='train', flatten = True):

        # --------------------------------------------------------------------------------------------------------
        # Prep object and raw data
        self.split = split
        self.flatten = flatten
             
        self.eeg, self.emb = data_dict[split]           
        # Compute size
        self.size = len(self.eeg)
    # Get size
    def __len__(self):
        return self.size
    # Get item
    def __getitem__(self, i):
        eeg_i = self.eeg[i]
        emb_i = self.emb[i]

        if self.flatten == True:
            eeg_i = eeg_i.flatten()
            emb_i = emb_i.flatten()
        to_return = (torch.from_numpy(eeg_i), torch.from_numpy(emb_i))
        return to_return

In [334]:
is_cuda = torch.cuda.is_available()
#print(is_cuda)
device = torch.device("cuda:0" if is_cuda else "cpu")
print(device)
max_epochs = 20

cuda:0


In [390]:
class TransformerModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        from torch.nn import TransformerEncoder, TransformerEncoderLayer
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Linear(ninp, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp*2, ntoken)

        self.init_weights()
        
        self.fc1 = nn.Linear(ninp, ninp*2)
        self.fc2 = nn.Linear(ninp*2, ninp)
        self.fc3 = nn.Linear(ninp*2, ninp)
        self.bn1 = nn.BatchNorm1d(num_features = int(ninp*2))
        self.bn2 = nn.BatchNorm1d(num_features = int(ninp))

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
 
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

        
    def forward(self, src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask
    
    
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        #print(output.shape)
        dim0 = output.shape[0]
        h1 = self.bn1(self.fc1(output[:,-1,:]))#.view(-1, ninp*2)))
        #h1 = h1.view(dim0, -1)
        print('h1', h1.shape)
        mu, logvar = self.fc2(h1), self.fc3(h1)
        
        output = self.decoder(self.bn1(h1))#output[:,-1,:]))
        return output, mu, logvar

In [391]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.long).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [392]:
import pickle
with open('finWFT_300_3_o350_s60lf20hf80.pkl', 'rb') as pickle_in:  
    dataset = pickle.load(pickle_in) 

In [393]:
#model = nn.Transformer(d_model = 8, nhead = 8, num_encoder_layers = 6).to(device)
ntoken, ninp, nhead, nhid, nlayers, dropout = 768, 3400, 2, 200, 2, 0.5
model = TransformerModel(ntoken, ninp, nhead, nhid, nlayers, dropout).to(device)

In [394]:
#ex = torch.ones(10,32,4)
#src = Variable(ex, requires_grad = False)
#tgt = Variable(ex, requires_grad = False)
#ey = model.forward(src, tgt)
#src.shape
#ey[0].shape

In [395]:
bsz = 64
train_set = EEGDataset(dataset, split = 'train')
train_generator = DataLoader(train_set, batch_size=bsz, shuffle=True)
test_set = EEGDataset(dataset, split = 'test')
test_generator = DataLoader(test_set, batch_size=bsz, shuffle=True)
print(dataset['train'][1].shape)
'''
train_src = Variable(torch.from_numpy(dataset['train'][0]), requires_grad = False)
train_tgt = Variable(torch.from_numpy(dataset['train'][0]), requires_grad = False)
test_src = Variable(torch.from_numpy(dataset['test'][0]), requires_grad = False)
test_tgt = Variable(torch.from_numpy(dataset['test'][0]), requires_grad = False)
dataset['train'][0].shape, dataset['train'][1].shape
'''

#for i, (eeg, emb) in enumerate(train_generator):
 #   print(i, eeg.shape, emb.shape)

(33762, 3, 256)


"\ntrain_src = Variable(torch.from_numpy(dataset['train'][0]), requires_grad = False)\ntrain_tgt = Variable(torch.from_numpy(dataset['train'][0]), requires_grad = False)\ntest_src = Variable(torch.from_numpy(dataset['test'][0]), requires_grad = False)\ntest_tgt = Variable(torch.from_numpy(dataset['test'][0]), requires_grad = False)\ndataset['train'][0].shape, dataset['train'][1].shape\n"

In [396]:
#kl divergence
def loss_function(recon_y, y, mu, logvar):
    bce = F.mse_loss(recon_y, y)
    print("bce", bce)
    kld = -0.5*torch.sum(1+logvar-mu.pow(2)-logvar.exp())
    print("kld", kld)
    return bce+kld

In [397]:
lr = 1e-2
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [398]:
train_losses = []
test_losses = []
loss_func = nn.CrossEntropyLoss()

for epoch in range(max_epochs):
    
    losses = {"train": 0, "test": 0}
    counts = {"train": 0, "test": 0}
    
    lr *= 0.5**(epoch//5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    
    
    model.train()
    torch.set_grad_enabled(True)
    
    for data_x, data_y in train_generator:
        #print(data_y.shape)
        #print(data_x.shape)
        
        #x2 = torch.stack([data_x for i in range(8)], axis = 2)
        #print(x2.shape)
        
        #x = x2.float().to(device)
        x = data_x.float().to(device)
        y = data_y.float().to(device)
        
        optimizer.zero_grad()
        #src = Variable(x, requires_grad = False)
        #print(src.size(1))
        #tgt = Variable(x, requires_grad = False)
        
        #output = model(x)
        
        #y = y.squeeze(1)
        
        #ndims = bsz*ntoken
        #loss = loss_func(output.view(-1, ndims), y) 
        
        recon_batch, mu, logvar = model(x)
        
        #print(recon_batch.shape, y.shape)
        loss = loss_function(recon_batch, y, mu, logvar)
        #print(loss, loss.item())
        
        loss.backward()
        optimizer.step()
        
        losses['train'] += loss.item()
        counts['train'] += 1
        
        
    
    for idx, (data_x, data_y) in enumerate(test_generator):
        x = data_x.float().to(device)
        y = data_y.float().to(device)
        
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(x)
        
        loss = loss_function(recon_batch, y, mu, logvar)
        
        losses['test'] += loss.item()
        counts['test'] += 1
        
        
    train_losses.append(losses['train']/counts['train'])
    test_losses.append(losses['test']/counts['test'])
    #print(train_losses, test_losses)
    
    print("Epoch[{}/{}] Training Loss: {}, Test Loss: {}".format(epoch+1, max_epochs, 
                                                                 train_losses[-1], test_losses[-1]))

h1 torch.Size([64, 6800])
bce tensor(22.6855, device='cuda:0', grad_fn=<MseLossBackward>)
kld tensor(55748.3281, device='cuda:0', grad_fn=<MulBackward0>)
h1 torch.Size([64, 6800])
bce tensor(22.1435, device='cuda:0', grad_fn=<MseLossBackward>)
kld tensor(444024.5000, device='cuda:0', grad_fn=<MulBackward0>)
h1 torch.Size([64, 6800])
bce tensor(72.1830, device='cuda:0', grad_fn=<MseLossBackward>)
kld tensor(3.7554e+16, device='cuda:0', grad_fn=<MulBackward0>)
h1 torch.Size([64, 6800])
bce tensor(42.6089, device='cuda:0', grad_fn=<MseLossBackward>)
kld tensor(2.5854e+18, device='cuda:0', grad_fn=<MulBackward0>)
h1 torch.Size([64, 6800])
bce tensor(93.1621, device='cuda:0', grad_fn=<MseLossBackward>)
kld tensor(2.9944e+30, device='cuda:0', grad_fn=<MulBackward0>)
h1 torch.Size([64, 6800])
bce tensor(nan, device='cuda:0', grad_fn=<MseLossBackward>)
kld tensor(nan, device='cuda:0', grad_fn=<MulBackward0>)
h1 torch.Size([64, 6800])
bce tensor(nan, device='cuda:0', grad_fn=<MseLossBackward>)


KeyboardInterrupt: 