# Variational Autoencoder Testing Framework
This notebook attempts to augment CNNs with VAEs. We train a VAE on the training set and then use the trained model to produce signals of reduced dimensionality that are fed into a CNN.

---



In [0]:
# Send SIGKILL to every process this user can signal (pid -1). Presumably used
# to force a fresh Colab runtime; run it manually only when that is wanted,
# never as part of "Run All".
!kill -9 -1

## Initialization
Mount Google Drive and install PyTorch.

In [1]:
# Install google-drive-ocamlfuse and authorize it with the Colab user's Google
# credentials, so the Drive folder can be mounted in the next cell.
# NOTE(review): `2>&1 > /dev/null` only discards stdout; stderr still reaches
# the notebook. Use `> /dev/null 2>&1` if both should be silenced.
!apt-get install -y -qq software-properties-common python-software-properties module-init-tools
!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
!apt-get update -qq 2>&1 > /dev/null
!apt-get -y install -qq google-drive-ocamlfuse fuse
from google.colab import auth
auth.authenticate_user()
from oauth2client.client import GoogleCredentials
creds = GoogleCredentials.get_application_default()
import getpass
# Print the OAuth URL the user must open to obtain a verification code.
!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
# Read the verification code without echoing it into the notebook output.
vcode = getpass.getpass()
!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}

Please, open the following URL in a web browser: https://accounts.google.com/o/oauth2/auth?client_id=32555940559.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive&response_type=code&access_type=offline&approval_prompt=force
··········
Please, open the following URL in a web browser: https://accounts.google.com/o/oauth2/auth?client_id=32555940559.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive&response_type=code&access_type=offline&approval_prompt=force
Please enter the verification code: Access token retrieved correctly.


In [0]:
# Mount Google Drive at ./drive; `data_dir` in the Setup section is relative to it.
!mkdir -p drive
!google-drive-ocamlfuse drive

In [0]:
# http://pytorch.org/
# Install the PyTorch 0.3.0 wheel built for this interpreter's tag: the `cu80`
# build when nvidia-smi exists (a GPU runtime), otherwise the CPU build.
# NOTE(review): `wheel.pep425tags` is gone in newer `wheel` releases, and the
# rest of the notebook relies on 0.3-era APIs (Variable, volatile=True,
# loss.data[0]); upgrading torch needs matching code changes.
from os import path
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())

accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'

!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.3.0.post4-{platform}-linux_x86_64.whl torchvision

## Imports

In [0]:
import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.cuda
from torch.utils.data import Dataset
from torch.autograd import Variable
from scipy import stats

## Classes

### EEG Dataset

In [0]:
class EEGDataset(Dataset):
  """EEG dataset: one (signal tensor, label tensor) pair per trial."""
  
  def __init__(self, x, y, transform=None):
    """
    Args:
      x (numpy array): Input data, one trial per leading index, e.g.
                       num_trials x num_electrodes x num_time_bins (the
                       container passes num_trials x 1 x electrodes x bins).
      y (numpy array): One integer label per trial (shape num_trials, or
                       num_trials x 1).
      transform (callable, optional): Applied to each input tensor returned
                       by __getitem__; labels are left untouched.
    """
    self.x = x
    self.y = y
    self.transform = transform
    
  def __len__(self):
    return len(self.x)
  
  def __getitem__(self, idx):
    """Return (x_sample, y_sample) for trial `idx`.

    x_sample is the trial converted with torch.from_numpy (then passed through
    `transform` if one was given); y_sample is a 1-element IntTensor.
    """
    x_sample = torch.from_numpy(self.x[idx])
    y_sample = torch.IntTensor([int(self.y[idx])])
    
    # Previously accepted but silently ignored (`pass #FIXME`).
    if self.transform is not None:
      x_sample = self.transform(x_sample)
    
    return x_sample, y_sample

### EEG Minimal Container

In [0]:
class EEGMinimalContainer():
  """EEG container for training and testing datasets.

  Loads every subject's A0iT_slice.mat file, optionally drops the EOG
  channels, removes trials containing NaNs, min-max scales the whole dataset
  to [0, 1], and holds out `num_test_per_subject` random trials of each
  subject as that subject's test set. All remaining trials form one shared
  training set.
  """

  _NUM_SUBJECTS = 9        # files A01T_slice.mat ... A09T_slice.mat
  _NUM_EEG_CHANNELS = 22   # channels kept when remove_eog_channels is True
  _LABEL_OFFSET = 769      # subtracted from raw 'type' codes (presumably the first class code) for 0-based labels
  
  def __init__(self, data_dir, train_subject=None, test_subject=None, 
               remove_eog_channels=True, seed=42, num_test_per_subject=50):
    """
    Args:
      data_dir (string): Path to all A0iT_slice.mat files for i in [1, 9].
      train_subject(int): Subject to train on. If None, train on all.
      test_subject(int): Subject to test on. If None, train on all except for
                         train_subject. Only used if train_subject is not None.
      remove_eog_channels (bool): Keep only the first 22 channels.
      seed (int): Seed for NumPy's global RNG, which drives the test split.
      num_test_per_subject (int): Number of trials held out per subject.

    Raises:
      NotImplementedError: If train_subject is given; per-subject splits are
        not implemented yet.
    """
    self.X_train = None
    self.y_train = None
    self.X_test = None
    self.y_test = None
    self.train_dataset = None
    self.test_dataset = None
    np.random.seed(seed)

    if train_subject is not None:
      # Fail loudly rather than return a container whose datasets are all None.
      raise NotImplementedError(
          'EEGMinimalContainer does not support train_subject yet')

    # Step 1: Append all of the input and output data together
    X, y, end = self._load_subjects(data_dir)
    X = np.expand_dims(X, axis=1)  # -> (trials, 1, channels, time_bins)
    y -= self._LABEL_OFFSET
    # Step 2: Remove the EOG
    if remove_eog_channels:
      X = X[:, :, 0:self._NUM_EEG_CHANNELS, :]
    # Step 3: Remove NaN trials (and shift the subject boundaries with them)
    X, y, end = self._drop_nan_trials(X, y, end)
    # Normalize the whole DATASET (train and test together) to [0, 1]
    Xmax = np.nanmax(X)
    Xmin = np.nanmin(X)
    X = (X - Xmin) / (Xmax - Xmin)
    # Step 4: Generate a per-subject train/test split
    self._split_train_test(X, y, end, num_test_per_subject)
    self._print_shapes()

  @classmethod
  def _load_subjects(cls, data_dir):
    """Read and concatenate every subject's trials.

    Returns:
      (X, y, end): the stacked 'image' arrays, int32 labels (one per trial),
      and end[k], the exclusive row in X where subject k+1's trials stop.
    """
    X_parts = []
    y_parts = []
    for subject in range(1, cls._NUM_SUBJECTS + 1):
      path = data_dir + ('/A0%dT_slice.mat' % subject)
      # `with` guarantees the HDF5 handle is closed once the data is copied.
      with h5py.File(path, 'r') as A0iT:
        X_subject = np.copy(A0iT['image'])
        y_subject = np.copy(A0iT['type'])
      # Keep the first X_subject.shape[0] entries of row 0: one label per trial.
      y_parts.append(np.asarray(y_subject[0, :X_subject.shape[0]], dtype=np.int32))
      X_parts.append(X_subject)
    end = np.cumsum([len(part) for part in X_parts])
    return np.concatenate(X_parts, axis=0), np.concatenate(y_parts, axis=0), end

  @staticmethod
  def _drop_nan_trials(X, y, end):
    """Delete every trial of X (and its label) that contains a NaN.

    `end` holds exclusive subject boundaries in the original row numbering;
    each boundary is reduced by the number of deleted rows that precede it in
    that same numbering. (Decrementing boundaries one removal at a time and
    comparing the already-shifted values with original indices undercounts.)

    Returns:
      (X, y, end) with the NaN trials removed and boundaries shifted.
    """
    nan_rows = np.flatnonzero(np.isnan(X).reshape(len(X), -1).any(axis=1))
    # nan_rows is sorted, so searchsorted(..., 'left') counts rows < boundary.
    end = end - np.searchsorted(nan_rows, end, side='left')
    return np.delete(X, nan_rows, axis=0), np.delete(y, nan_rows, axis=0), end

  def _split_train_test(self, X, y, end, num_test_per_subject):
    """Hold out `num_test_per_subject` random trials of each subject.

    Fills X_test / y_test / test_dataset (dicts keyed '1', '2', ...) and
    X_train / y_train / train_dataset with all trials not held out.
    """
    self.X_test = {}
    self.y_test = {}
    self.test_dataset = {}
    held_out = []
    start = 0
    for subject, stop in enumerate(end, 1):
      rows = np.random.choice(np.arange(start, stop), num_test_per_subject,
                              replace=False).astype(int)
      key = str(subject)
      self.X_test[key] = X[rows, :, :, :]
      self.y_test[key] = y[rows]
      self.test_dataset[key] = EEGDataset(self.X_test[key], self.y_test[key])
      held_out.extend(rows.tolist())
      start = stop
    self.X_train = np.delete(X, held_out, axis=0)
    self.y_train = np.delete(y, held_out, axis=0)
    self.train_dataset = EEGDataset(self.X_train, self.y_train)

  def _print_shapes(self):
    """Print the shapes of the training arrays and of each subject's test arrays."""
    print('EEGContainer X_train: ' + str(self.X_train.shape))
    print('EEGContainer y_train: ' + str(self.y_train.shape))
    for key in self.X_test:
      i = int(key)
      print(('EEGContainer X_test%d: ' %i) + str(self.X_test[key].shape))
      print(('EEGContainer y_test%d: ' %i) + str(self.y_test[key].shape))

### Convolutional Neural Network

In [0]:
class CNN(nn.Module):
  """Convolutional classifier producing 4 logits per EEG trial.

  Expects input of shape (N, 1, 22, 1000): that is the size for which the
  last feature map flattens to the 64 * 2 * 10 inputs of fc1.

  Args:
    dropout (float): Dropout probability after every layer except conv1/fc3.
    input_dropout (float): Dropout probability after conv1.
  """
  def __init__(self, dropout=0.7, input_dropout=0.2):
    super(CNN, self).__init__()
    self.dropout_p = dropout
    self.input_dropout_p = input_dropout
    self.conv1 = nn.Conv2d( 1, 16, (1, 11), stride=(1, 1), padding=0)
    self.conv2 = nn.Conv2d(16, 16, (22, 1), stride=(1, 1), padding=0)
    self.conv3 = nn.Conv2d( 1, 16, (1, 11), stride=(1, 1), padding=0)
    self.conv4 = nn.Conv2d(16, 16, (16, 1), stride=(1, 1), padding=0)
    self.conv5 = nn.Conv2d( 1, 16, 3, stride=1, padding=1)
    self.conv6 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
    self.conv7 = nn.Conv2d(32, 64, 3, stride=1, padding=1)
    self.fc1 = nn.Linear(64 * 2 * 10, 200)
    self.fc2 = nn.Linear(200, 100)
    self.fc3 = nn.Linear(100, 4)
  
  def forward(self, x):
    p = self.dropout_p
    # Functional dropout ignores the module's train/eval mode (and its
    # `training` default has differed between PyTorch releases), so pass the
    # mode explicitly: active under net.train(), identity under net.eval().
    training = self.training
    # Shape comments assume an (N, 1, 22, 1000) input.
    x = F.dropout2d(F.relu(self.conv1(x)), p=self.input_dropout_p, training=training)  # (N, 16, 22, 990)
    x = F.dropout2d(F.relu(self.conv2(x)), p=p, training=training)  # (N, 16, 1, 990)
    x = x.permute(0, 2, 1, 3)                                       # (N, 1, 16, 990)
    x = F.max_pool2d(x, (1, 3), (1, 3))                             # (N, 1, 16, 330)
    x = F.dropout2d(F.relu(self.conv3(x)), p=p, training=training)  # (N, 16, 16, 320)
    x = F.dropout2d(F.relu(self.conv4(x)), p=p, training=training)  # (N, 16, 1, 320)
    x = x.permute(0, 2, 1, 3)                                       # (N, 1, 16, 320)
    x = F.max_pool2d(x, (1, 4), (1, 4))                             # (N, 1, 16, 80)
    x = F.dropout2d(F.relu(self.conv5(x)), p=p, training=training)  # (N, 16, 16, 80)
    x = F.max_pool2d(x, 2, 2)                                       # (N, 16, 8, 40)
    x = F.dropout2d(F.relu(self.conv6(x)), p=p, training=training)  # (N, 32, 8, 40)
    x = F.max_pool2d(x, 2, 2)                                       # (N, 32, 4, 20)
    x = F.dropout2d(F.relu(self.conv7(x)), p=p, training=training)  # (N, 64, 4, 20)
    x = F.max_pool2d(x, 2, 2)                                       # (N, 64, 2, 10)
    x = x.view(-1, 64 * 2 * 10)
    x = F.dropout(F.relu(self.fc1(x)), p=p, training=training)
    x = F.dropout(F.relu(self.fc2(x)), p=p, training=training)
    x = self.fc3(x)
    return x

### Variational Autoencoder

In [0]:
class VAE(nn.Module):
    """Fully connected variational auto-encoder over flattened EEG trials.

    Encoder: input_dim -> hidden_dim (ReLU) -> (mu, logvar), each z_dims wide.
    Decoder: z_dims -> hidden_dim (ReLU) -> input_dim (sigmoid).

    The defaults (22000 = 22 channels x 1000 time bins, 400 hidden units,
    20 latent dimensions) reproduce the original hard-coded architecture.
    """

    def __init__(self, input_dim=22000, hidden_dim=400, z_dims=20):
        super(VAE, self).__init__()
        self.input_dim = input_dim

        # ENCODER
        # input_dim flattened input values -> hidden_dim units
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # max(0, x)
        self.relu = nn.ReLU()
        self.fc21 = nn.Linear(hidden_dim, z_dims)  # mu layer
        self.fc22 = nn.Linear(hidden_dim, z_dims)  # logvariance layer
        # this last layer bottlenecks through z_dims connections

        # DECODER
        # from bottleneck back to hidden_dim units
        self.fc3 = nn.Linear(z_dims, hidden_dim)
        # from hidden_dim units to input_dim outputs
        self.fc4 = nn.Linear(hidden_dim, input_dim)
        self.sigmoid = nn.Sigmoid()

    def encode(self, x: Variable) -> (Variable, Variable):
        """Encode a flattened batch: fc1 -> ReLU -> (fc21, fc22).

        Parameters
        ----------
        x : (batch, input_dim) flattened trials

        Returns
        -------
        (mu, logvar) : two (batch, z_dims) matrices holding the mean and the
            log-variance of the latent Gaussian for every latent dimension
        """
        h1 = self.relu(self.fc1(x))  # (batch, hidden_dim)
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu: Variable, logvar: Variable) -> Variable:
        """Draw a latent sample with the reparameterization trick.

        In training mode each row becomes mu + eps * std with
        std = exp(0.5 * logvar) and eps ~ N(0, 1), so gradients flow back into
        mu and logvar through a deterministic expression. In eval mode the
        mean `mu` itself is returned.

        Parameters
        ----------
        mu : (batch, z_dims) mean matrix
        logvar : (batch, z_dims) log-variance matrix

        Returns
        -------
        (batch, z_dims) latent sample (training) or `mu` (inference)
        """
        if self.training:
            # multiply log variance with 0.5, then in-place exponent
            # yielding the standard deviation
            std = logvar.mul(0.5).exp_()  # type: Variable
            # eps has std's shape, elements drawn from N(0, 1)
            eps = Variable(std.data.new(std.size()).normal_())
            # shift/scale the unit-normal sample to N(mu, std^2),
            # see https://stats.stackexchange.com/a/16338
            return eps.mul(std).add_(mu)

        else:
            # During inference return the mean of the learned distribution
            # (its most probable point) instead of a random sample.
            return mu

    def decode(self, z: Variable) -> Variable:
        """Map (batch, z_dims) latents to (batch, input_dim) values in (0, 1)."""
        h3 = self.relu(self.fc3(z))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x: Variable) -> (Variable, Variable, Variable):
        """Return (reconstruction, mu, logvar) for a batch.

        `x` may have any layout with input_dim values per sample (e.g.
        (batch, 1, 22, 1000) with the defaults); it is flattened first, so the
        reconstruction has shape (batch, input_dim).
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [0]:
# Reconstruction + KL divergence losses summed over all elements and batch
def vae_loss_function(recon_x, x, mu, logvar):
  
  BCE = F.binary_cross_entropy(recon_x, x.view(-1, 22000), size_average=False)

  # see Appendix B from VAE paper:
  # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
  # https://arxiv.org/abs/1312.6114
  # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
  #print(logvar)
  #print(logvar.exp)
  KLD = 0#-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
  
  return BCE + KLD

## Setup

In [0]:
# Training mini-batch size (the recorded run's 2108 training trials split
# into 68 full batches of 31) and the dataset folder on the mounted Drive.
batch_size = 31
data_dir = 'drive/ee239as/project_datasets'

In [72]:
EEGset = EEGMinimalContainer(data_dir)

# Shuffled mini-batches for training; each subject's held-out trials are
# served one at a time in a fixed order, keyed '1' ... '9'.
train_loader = torch.utils.data.DataLoader(
    EEGset.train_dataset, batch_size=batch_size, shuffle=True)
test_loader = {
    str(subject): torch.utils.data.DataLoader(
        EEGset.test_dataset[str(subject)], batch_size=1, shuffle=False)
    for subject in range(1, 10)
}

EEGContainer X_train: (2108, 1, 22, 1000)
EEGContainer y_train: (2108,)
EEGContainer X_test1: (50, 1, 22, 1000)
EEGContainer y_test1: (50,)
EEGContainer X_test2: (50, 1, 22, 1000)
EEGContainer y_test2: (50,)
EEGContainer X_test3: (50, 1, 22, 1000)
EEGContainer y_test3: (50,)
EEGContainer X_test4: (50, 1, 22, 1000)
EEGContainer y_test4: (50,)
EEGContainer X_test5: (50, 1, 22, 1000)
EEGContainer y_test5: (50,)
EEGContainer X_test6: (50, 1, 22, 1000)
EEGContainer y_test6: (50,)
EEGContainer X_test7: (50, 1, 22, 1000)
EEGContainer y_test7: (50,)
EEGContainer X_test8: (50, 1, 22, 1000)
EEGContainer y_test8: (50,)
EEGContainer X_test9: (50, 1, 22, 1000)
EEGContainer y_test9: (50,)


In [0]:
# CNN training schedule.
learning_rate = 1e-5
num_epochs = 200

# Request the GPU; later cells only use it if torch.cuda.is_available().
use_cuda = True

criterion = nn.CrossEntropyLoss()  # classification loss for the CNN logits
# Build the CNN before the VAE so the random weight initialisation order
# stays the same.
net = CNN()
model = VAE()

In [0]:
# Move both networks to the GPU only when requested AND available; the
# optimizers are built afterwards so they hold the (possibly moved) parameters.
if use_cuda and torch.cuda.is_available():
  model.cuda()
  net.cuda()

# CNN: small learning rate with L2 weight decay. VAE: larger learning rate.
vae_optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=1e-3)

## Training

### VAE Training Functions

In [0]:
def vae_train(epoch):
    """Run one optimisation pass of the global VAE `model` over `train_loader`.

    Class labels are ignored. Each batch is cast to float, moved to the GPU
    when enabled, and used for one `vae_optimizer` step on
    `vae_loss_function`. Prints the summed loss divided by the number of
    training trials.

    Args:
      epoch: Epoch number, used only in the printed message.
    """
    model.train()
    on_gpu = use_cuda and torch.cuda.is_available()
    running_loss = 0
    for signals, _labels in train_loader:
        batch = Variable(signals.type(torch.FloatTensor))
        if on_gpu:
            batch = batch.cuda()

        vae_optimizer.zero_grad()
        reconstruction, mu, logvar = model(batch)
        batch_loss = vae_loss_function(reconstruction, batch, mu, logvar)
        batch_loss.backward()
        running_loss += batch_loss.data[0]
        vae_optimizer.step()

    mean_loss = running_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, mean_loss))

In [0]:
def vae_test(epoch):
    """Print the VAE's mean per-trial loss on every subject's held-out trials.

    Uses the global `model`, `test_loader` (dict keyed '1', '2', ...) and
    `use_cuda`. The model is switched to eval mode, so reparameterize returns
    the latent mean.

    Args:
      epoch: Unused; kept so the call matches vae_train(epoch).
    """
    model.eval()

    print('Test set loss: [', end='')
    for subject in range(1, len(test_loader) + 1):
        loader = test_loader[str(subject)]
        test_loss = 0
        for data, _ in loader:
            data = data.type(torch.FloatTensor)
            if use_cuda and torch.cuda.is_available():
                data = data.cuda()
            # Inference only: volatile=True disables autograd (PyTorch 0.3 API).
            data = Variable(data, volatile=True)
            recon_batch, mu, logvar = model(data)
            test_loss += vae_loss_function(recon_batch, data, mu, logvar).data[0]
        # The former unused `comparison` tensor (left over from an MNIST
        # image-saving example) was dropped: its hard-coded
        # view(1, 1, 22, 1000) crashed for any test batch larger than one.
        test_loss /= len(loader.dataset)
        print('{:.4f}, '.format(test_loss), end='')
    print(']')

### Train the VAE

In [77]:
# 100 VAE epochs: each trains on the whole training set, then reports the
# loss on every subject's held-out trials.
for epoch in range(1, 100 + 1):
  vae_train(epoch)
  vae_test(epoch)

====> Epoch: 1 Average loss: 15323.9618
Test set loss: [15246.0625, 15246.5665, 15245.6041, 15243.8028, 15247.0920, 15241.9612, 15244.7264, 15249.6363, 15245.2642, ]
====> Epoch: 2 Average loss: 15245.2052
Test set loss: [15245.2591, 15245.9887, 15244.8032, 15242.9437, 15246.3472, 15240.8340, 15244.0160, 15249.2621, 15244.5770, ]
====> Epoch: 3 Average loss: 15245.0519
Test set loss: [15245.2963, 15246.1561, 15245.0465, 15242.8012, 15246.8002, 15240.7036, 15243.9282, 15249.4593, 15244.5995, ]
====> Epoch: 4 Average loss: 15245.3874
Test set loss: [15245.0232, 15246.1940, 15245.1352, 15242.9194, 15246.7922, 15240.0490, 15243.7722, 15249.3569, 15244.3540, ]
====> Epoch: 5 Average loss: 15245.9726
Test set loss: [15245.1778, 15246.5016, 15245.3593, 15243.3317, 15246.9329, 15240.6438, 15243.8772, 15249.0929, 15244.6103, ]
====> Epoch: 6 Average loss: 15244.8590
Test set loss: [15243.4200, 15245.4454, 15244.4054, 15241.9279, 15245.4810, 15238.7581, 15242.8403, 15247.4583, 15243.4116, ]
====

15237.8050, 15231.1674, ]
====> Epoch: 21 Average loss: 15234.1996
Test set loss: [15228.4227, 15241.5238, 15239.2915, 15236.0789, 15236.1554, 15226.7547, 15236.1439, 15237.4322, 15231.5660, ]
====> Epoch: 22 Average loss: 15231.7785
Test set loss: [15225.7881, 15239.4371, 15236.9687, 15233.7883, 15233.7872, 15224.1408, 15233.8068, 15234.8733, 15228.0485, ]
====> Epoch: 23 Average loss: 15231.5321
Test set loss: [15226.4624, 15240.7956, 15238.2801, 15235.0404, 15235.7076, 15224.7097, 15235.4093, 15235.4653, 15229.3145, ]
====> Epoch: 24 Average loss: 15230.5403
Test set loss: [15223.6538, 15238.5712, 15236.3805, 15233.1211, 15232.7090, 15222.4164, 15233.1111, 15233.1070, 15226.9466, ]
====> Epoch: 25 Average loss: 15229.5935
Test set loss: [15223.1361, 15238.5234, 15236.0137, 15232.8103, 15232.2336, 15221.2180, 15232.6222, 15232.5611, 15225.8651, ]
====> Epoch: 26 Average loss: 15230.0917
Test set loss: [15225.8488, 15240.5160, 15238.4929, 15234.6490, 15234.0769, 15223.7327, 15234.2369

15230.3256, 15229.1138, 15216.3965, 15229.1830, 15227.2452, 15219.0586, ]
====> Epoch: 41 Average loss: 15219.4284
Test set loss: [15218.4504, 15236.8518, 15233.7760, 15231.4932, 15229.9634, 15217.1933, 15229.7106, 15227.7806, 15219.1211, ]
====> Epoch: 42 Average loss: 15219.9529
Test set loss: [15217.2425, 15235.5725, 15232.7132, 15230.1435, 15229.0638, 15216.3499, 15228.6195, 15226.5144, 15218.8439, ]
====> Epoch: 43 Average loss: 15218.6139
Test set loss: [15219.4765, 15236.9672, 15234.2776, 15231.9604, 15231.1030, 15217.9172, 15230.6413, 15228.1709, 15221.0479, ]
====> Epoch: 44 Average loss: 15218.0459
Test set loss: [15219.2208, 15237.7365, 15234.7432, 15232.9139, 15231.0706, 15218.2641, 15231.2882, 15228.9907, 15221.7409, ]
====> Epoch: 45 Average loss: 15218.0382
Test set loss: [15217.1530, 15235.7906, 15232.6919, 15230.3914, 15229.1043, 15216.4219, 15229.1680, 15226.2641, 15219.4115, ]
====> Epoch: 46 Average loss: 15217.1424
Test set loss: [15218.3777, 15235.9910, 15232.9770

====> Epoch: 60 Average loss: 15211.6371
Test set loss: [15223.0204, 15236.8460, 15234.6872, 15231.9538, 15231.0245, 15219.5063, 15230.8494, 15229.5985, 15223.9682, ]
====> Epoch: 61 Average loss: 15212.0943
Test set loss: [15225.1959, 15238.7543, 15235.7283, 15233.9635, 15234.7318, 15224.6030, 15233.4788, 15232.5044, 15230.5709, ]
====> Epoch: 62 Average loss: 15212.4113
Test set loss: [15222.0728, 15236.1681, 15234.1411, 15231.1590, 15231.0109, 15219.0982, 15230.2535, 15229.8307, 15225.2750, ]
====> Epoch: 63 Average loss: 15210.7076
Test set loss: [15224.7698, 15237.3972, 15234.9279, 15232.2659, 15232.7799, 15221.6920, 15231.8646, 15231.6960, 15227.5169, ]
====> Epoch: 64 Average loss: 15211.6868
Test set loss: [15222.0743, 15236.2795, 15233.8402, 15231.1285, 15230.7012, 15220.5966, 15230.1470, 15230.3823, 15225.5326, ]
====> Epoch: 65 Average loss: 15211.0141
Test set loss: [15223.6656, 15236.3038, 15233.9380, 15231.8360, 15231.8214, 15221.2303, 15230.6907, 15230.8489, 15225.4394, 

15237.4718, 15233.0496, ]
====> Epoch: 80 Average loss: 15206.7776
Test set loss: [15231.8291, 15238.2588, 15235.6147, 15232.4346, 15234.1515, 15225.5176, 15233.3468, 15235.2985, 15231.2961, ]
====> Epoch: 81 Average loss: 15206.1695
Test set loss: [15232.9491, 15237.7744, 15236.7722, 15234.1323, 15235.5728, 15226.7232, 15234.4468, 15236.2996, 15237.4903, ]
====> Epoch: 82 Average loss: 15204.4489
Test set loss: [15231.0120, 15238.5600, 15236.5641, 15233.0386, 15234.2761, 15225.6355, 15233.0411, 15235.1001, 15230.1587, ]
====> Epoch: 83 Average loss: 15206.0252
Test set loss: [15229.2099, 15238.5197, 15236.0512, 15233.3301, 15234.1230, 15226.7411, 15233.6210, 15235.4699, 15234.5603, ]
====> Epoch: 84 Average loss: 15205.6204
Test set loss: [15229.0604, 15236.9098, 15235.0285, 15232.0501, 15233.5148, 15226.1821, 15232.3935, 15235.6032, 15234.0652, ]
====> Epoch: 85 Average loss: 15203.7980
Test set loss: [15228.5157, 15236.4178, 15234.7979, 15230.9892, 15231.7916, 15224.6293, 15230.9701

15231.1860, 15232.8546, 15228.2245, 15231.5245, 15236.5215, 15234.1919, ]
====> Epoch: 100 Average loss: 15200.6278
Test set loss: [15236.2091, 15238.5571, 15238.6511, 15233.9654, 15236.3086, 15233.1301, 15235.7608, 15244.1993, 15246.0118, ]


In [0]:
# Epochs with index > VAE_INPUT_EPOCH train the CNN on VAE reconstructions
# instead of the raw signals.
VAE_INPUT_EPOCH = 80


def _prepare_batch(signals, labels):
  """Cast a DataLoader batch to (float signals Variable, 1-D long label tensor).

  Both are moved to the GPU when enabled. view(-1) is used instead of
  squeeze() so a batch of one still yields a 1-D label vector.
  """
  signals = Variable(signals.type(torch.FloatTensor))
  labels = labels.type(torch.LongTensor).view(-1)
  if use_cuda and torch.cuda.is_available():
    signals = signals.cuda()
    labels = labels.cuda()
  return signals, labels


def _accuracy(loader):
  """Fraction of trials in `loader` that the global `net` classifies correctly."""
  total = 0
  correct = 0
  for signals, labels in loader:
    signals, labels = _prepare_batch(signals, labels)
    _, predicted = torch.max(net(signals).data, 1)
    total += labels.size(0)
    # int(): a plain Python count whether .sum() yields a number or a tensor.
    correct += int((predicted == labels).sum())
  return correct / total


training_acc_arr = np.empty(num_epochs)
testing_acc_arr = np.empty((len(test_loader), num_epochs))
steps_per_epoch = len(EEGset.train_dataset) // batch_size

for epoch in range(num_epochs):
  
  net.train()
  
  for i, (signals, labels) in enumerate(train_loader):
    signals, labels = _prepare_batch(signals, labels)
    labels = Variable(labels)
    
    if epoch > VAE_INPUT_EPOCH:
      reconstruction, _, _ = model(signals)
      # detach(): the CNN loss must not backpropagate through the VAE.
      signals = reconstruction.detach().view(-1, 1, 22, 1000)
    
    optimizer.zero_grad()
    outputs = net(signals)
    
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    
    if (i+1) % 17 == 0:
      print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f' 
            % (epoch+1, num_epochs, i+1, steps_per_epoch, loss.data[0]))
  
  net.eval()
  
  # NOTE(review): accuracy is measured on raw signals even after the CNN has
  # switched to VAE reconstructions for training; confirm that is intended.
  training_acc_arr[epoch] = _accuracy(train_loader)
  print ('Training Accuracy: %.5f' % training_acc_arr[epoch])
  
  for subject in range(len(test_loader)):
    testing_acc_arr[subject, epoch] = _accuracy(test_loader[str(subject+1)])
  print ('Testing Accuracy: ' + str(testing_acc_arr[:, epoch]))
  print ('Testing Accuracy Average: %.5f' % np.average(testing_acc_arr[:, epoch]))

Epoch [1/200], Step [17/68], Loss: 1.3801
Epoch [1/200], Step [34/68], Loss: 1.3819
Epoch [1/200], Step [51/68], Loss: 1.3932
Epoch [1/200], Step [68/68], Loss: 1.3790
Training Accuracy: 0.25522
Testing Accuracy: [0.22 0.2  0.22 0.24 0.2  0.28 0.22 0.22 0.24]
Testing Accuracy Average: 0.22667
Epoch [2/200], Step [17/68], Loss: 1.3914
Epoch [2/200], Step [34/68], Loss: 1.3771
Epoch [2/200], Step [51/68], Loss: 1.3900
Epoch [2/200], Step [68/68], Loss: 1.4102
Training Accuracy: 0.25522
Testing Accuracy: [0.22 0.2  0.22 0.24 0.2  0.28 0.22 0.22 0.24]
Testing Accuracy Average: 0.22667
Epoch [3/200], Step [17/68], Loss: 1.3862
Epoch [3/200], Step [34/68], Loss: 1.3920
Epoch [3/200], Step [51/68], Loss: 1.3954
Epoch [3/200], Step [68/68], Loss: 1.3966
Training Accuracy: 0.25522
Testing Accuracy: [0.22 0.2  0.22 0.24 0.2  0.28 0.22 0.22 0.24]
Testing Accuracy Average: 0.22667
Epoch [4/200], Step [17/68], Loss: 1.3919
Epoch [4/200], Step [34/68], Loss: 1.3885
Epoch [4/200], Step [51/68], Loss: