In [1]:
import os

import numpy as np
import pywt
import scipy.io
from scipy import signal
try:
    from scipy.integrate import simps
except ImportError:
    # `simps` was removed in SciPy 1.14; `simpson` is its replacement
    # (note: in recent SciPy its `x` argument is keyword-only).
    from scipy.integrate import simpson as simps

In [2]:
def loadData(folderName):
    """Load every per-subject MATLAB file in ``folderName``.

    Each file is read with ``scipy.io.loadmat`` and must contain ``'data'`` and
    ``'labels'`` variables; the per-file arrays are stacked with ``np.array``
    along a new leading (subject) axis.

    Entries are processed in sorted filename order so that subject index ``i``
    refers to the same file on every machine (``os.listdir`` returns entries in
    arbitrary order). Hidden entries (e.g. ``.DS_Store``,
    ``.ipynb_checkpoints``) and sub-directories are skipped instead of being
    handed to ``loadmat``.

    Parameters
    ----------
    folderName : str
        Directory holding the per-subject ``.mat`` files.

    Returns
    -------
    tuple of np.ndarray
        ``(data, labels)`` stacked over subjects.
    """
    data = []
    labels = []
    for fileName in sorted(os.listdir(folderName)):
        filePath = os.path.join(folderName, fileName)
        if fileName.startswith('.') or not os.path.isfile(filePath):
            continue
        loadedFile = scipy.io.loadmat(filePath)
        data.append(loadedFile['data'])
        labels.append(loadedFile['labels'])

    return np.array(data), np.array(labels)

In [3]:
# Stack every subject file of the preprocessed dataset
data, labels = loadData(folderName="../Datasets/data_preprocessed_matlab")

In [4]:
# data:   (subject, video, channel, sample)
# labels: (subject, video, rating); rating columns 0/1/2 are used below as
#         valence / arousal / dominance
print(data.shape)
print(labels.shape)

(32, 40, 40, 8064)
(32, 40, 4)


In [5]:
def getSpecificChannels(selectedChannels=None):
    """Return the indices of the chosen electrodes within the 32-channel list.

    ``channels`` is the 32-name channel list used to index the channel axis of
    the loaded data (assumed to match the dataset's channel order — the
    indices returned here are used directly as ``data[:, :, idx, :]``).

    Parameters
    ----------
    selectedChannels : sequence of str, optional
        Channel names to keep. Defaults to the 14 electrodes
        AF3, F3, F7, FC5, T7, P7, O1, O2, P8, T8, FC6, F8, F4, AF4.

    Returns
    -------
    np.ndarray of int
        Indices in ascending channel-list order (NOT in the order of
        ``selectedChannels``).

    Raises
    ------
    ValueError
        If a requested name is not in the channel list (previously such
        names were silently dropped).
    """
    channels = np.array(['Fp1', 'AF3', 'F3', 'F7', 'FC5', 'FC1', 'C3', 'T7', 'CP5', 'CP1', 'P3', 'P7', 'PO3', 'O1', 'Oz',
                        'Pz', 'Fp2', 'AF4', 'Fz', 'F4', 'F8', 'FC6', 'FC2', 'Cz', 'C4', 'T8', 'CP6', 'CP2', 'P4', 'P8', 'PO4', 'O2'])
    if selectedChannels is None:
        selectedChannels = ['AF3', 'F3', 'F7', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F8', 'F4', 'AF4']
    selectedChannel = np.asarray(list(selectedChannels), dtype=str)

    known = set(channels.tolist())
    unknown = [name for name in selectedChannel.tolist() if name not in known]
    if unknown:
        raise ValueError(f"Unknown channel name(s): {unknown}")

    # np.isin replaces np.in1d, which is deprecated since NumPy 2.0
    return np.where(np.isin(channels, selectedChannel))[0]

In [6]:
# Positions of the selected electrodes in the 32-channel list (ascending order)
channelIndices = getSpecificChannels()
channelIndices

array([ 1,  2,  3,  4,  7, 11, 13, 17, 19, 20, 21, 25, 29, 31])

In [7]:
# Keep only the 14 selected electrodes (axis 2 is the channel axis)
selectedChannelData = np.take(data, channelIndices, axis=2)
print(selectedChannelData.shape)

(32, 40, 14, 8064)


In [8]:
# Remove the first 20 seconds of each video and keep the remaining 40 s.
# Each trial holds 8064 samples = 63 s at 128 Hz. Per the DEAP preprocessed-data
# description the first 3 s are a pre-trial baseline, so the video itself spans
# [3 s, 63 s). The previous slice [20 s, 60 s) actually kept video seconds
# 17-57 and dropped the last 3 s of the video.
sampleRate = 128
baselineSeconds = 3
skipSeconds = 20
keepSeconds = 40
windowStart = (baselineSeconds + skipSeconds) * sampleRate
removed20SecData = selectedChannelData[:, :, :, windowStart:windowStart + keepSeconds * sampleRate]
# Fail loudly if the trials are shorter than expected instead of silently
# producing a short window that breaks the 10 s segmentation below.
assert removed20SecData.shape[-1] == keepSeconds * sampleRate, removed20SecData.shape

In [9]:
np.shape(removed20SecData)

(32, 40, 14, 5120)

In [10]:
# Split each window into consecutive 10 s segments and treat every segment as
# its own sample: segment i of video v lands at row v * segmentsPerVideo + i.
# Dimensions are taken from the data instead of hard-coding (32, 160, 14, 1280).
segmentSamples = sampleRate * 10
segmentsPerVideo = removed20SecData.shape[-1] // segmentSamples  # 4 for a 40 s window
nSubjects, nVideos, nChannels = removed20SecData.shape[:3]

processedData = (
    removed20SecData[..., :segmentsPerVideo * segmentSamples]
    # split the sample axis into (segment, sample-within-segment)
    .reshape(nSubjects, nVideos, nChannels, segmentsPerVideo, segmentSamples)
    # -> (subject, video, segment, channel, sample)
    .transpose(0, 1, 3, 2, 4)
    # merge (video, segment) into one axis, video-major
    .reshape(nSubjects, nVideos * segmentsPerVideo, nChannels, segmentSamples)
    .astype(np.float64, copy=False)
)

# Repeat each video's ratings once per segment so labels stay row-aligned
processedLabels = np.repeat(labels, segmentsPerVideo, axis=1)

In [11]:
for shapedArray in (processedData, processedLabels):
    print(shapedArray.shape)

(32, 160, 14, 1280)
(32, 160, 4)


In [12]:
# 1 when the rating is >= 5, else 0, for rating columns 0 (valence),
# 1 (arousal) and 2 (dominance)
valenceLabels, arousalLabels, dominanceLabels = (
    (processedLabels[:, :, column] >= 5) * 1 for column in range(3)
)

In [13]:
# Loop bounds for feature extraction, read from the segmented data
# (subject, segment, channel, sample) instead of repeating magic numbers.
subjectCount, videoCount, channelCount = processedData.shape[:3]

In [16]:
def waveletFunc(data, subjectCount, videoCount, channelCount, wavelet='db4', level=4):
    """Compute the energy of each DWT detail band for every EEG signal.

    Every signal ``data[s, v, ch]`` (for the first ``subjectCount`` subjects,
    ``videoCount`` videos/segments and ``channelCount`` channels) is decomposed
    with ``pywt.wavedec(..., mode='symmetric', level=level)``. The energy
    (sum of squared coefficients) of each detail band is kept; the
    approximation band is ignored.

    Parameters
    ----------
    data : array-like
        Indexed as ``data[subject, video, channel, sample]``.
    subjectCount, videoCount, channelCount : int
        How many leading subjects / videos / channels to process.
    wavelet : str, optional
        Wavelet name passed to ``pywt.wavedec`` (default ``'db4'``).
    level : int, optional
        Decomposition level (default 4), i.e. the number of detail bands.

    Returns
    -------
    np.ndarray
        Shape ``(subjectCount, videoCount, channelCount, 1, level)``, energies
        ordered ``[cD1, cD2, ..., cD<level>]`` (finest band first). The
        singleton axis is kept for compatibility with callers that flatten
        the trailing three axes.

    Raises
    ------
    IndexError
        If ``data`` has fewer subjects, videos or channels than requested.
    """
    allSignals = np.asarray(data)
    signals = allSignals[:subjectCount, :videoCount, :channelCount]
    if signals.shape[:3] != (subjectCount, videoCount, channelCount):
        raise IndexError(
            f"data with shape {allSignals.shape} is smaller than the requested "
            f"({subjectCount}, {videoCount}, {channelCount}) subjects/videos/channels")

    subjectFeatures = []
    for subjectSignals in signals:
        # One call decomposes all (video, channel) signals of a subject along
        # the sample axis; going subject by subject bounds the memory used by
        # the coefficient arrays.
        coeffs = pywt.wavedec(subjectSignals, wavelet, mode='symmetric', level=level, axis=-1)
        # wavedec returns [cA<level>, cD<level>, ..., cD1]; reverse the detail
        # bands so the feature order is cD1, cD2, ..., cD<level>.
        bandEnergies = [np.square(detail).sum(axis=-1) for detail in reversed(coeffs[1:])]
        # (video, channel, band) -> (video, channel, 1, band)
        subjectFeatures.append(np.stack(bandEnergies, axis=-1)[:, :, np.newaxis, :])

    return np.array(subjectFeatures)

In [17]:
# Detail-band energies for every (subject, segment, channel)
features = waveletFunc(processedData,
                       subjectCount=subjectCount,
                       videoCount=videoCount,
                       channelCount=channelCount)

In [18]:
np.shape(features)

(32, 160, 14, 1, 4)

In [19]:
# Flatten (channel, 1, band) into a single feature vector per sample.
# Using -1 keeps the cell idempotent: re-running it on the already-flattened
# array is a no-op instead of raising IndexError on features.shape[3].
features = features.reshape(features.shape[0], features.shape[1], -1)

In [20]:
np.shape(features)

(32, 160, 56)

In [21]:
import pickle

# Cache the flattened feature array next to the notebook
waveletFeaturesFolder = 'waveletFeatures.dat'
with open(waveletFeaturesFolder, 'wb') as outFile:
    pickle.dump(features, outFile)

In [22]:
# Reload the cached features. This is a file the notebook wrote itself;
# never unpickle files from untrusted sources.
with open(waveletFeaturesFolder, 'rb') as inFile:
    waveletFeatures = pickle.load(inFile)

In [23]:
print(np.shape(waveletFeatures))

(32, 160, 56)
