In [1]:
import numpy as np
import nibabel as nbl

In [2]:
# Text file listing the data paths to use, one path per line.
HCPDataListPath = './s2_con1_label.txt'
# Prefix prepended to every listed path when the data is loaded.
HCPDataRootPath = '../../data/disk/'
# Integer class id for each HCP task name, assigned in the order listed.
labelNumber = {taskName: classId for classId, taskName in enumerate((
    'EMOTION', 'GAMBLING', 'LANGUAGE', 'MOTOR', 'RELATIONAL', 'SOCIAL', 'WM',
))}

In [3]:
def getLabelList(HCPDataPathList):
    '''
    Map every data path to its integer task label via `labelNumber`.

    Parameters:
        HCPDataPathList: (list of string) Relative data paths.
    Returns:
        list of int: One label per path, in the same order.
    '''
    def taskNameOf(path):
        # Take the 6th '/'-separated component and return its 2nd
        # '_'-separated token. NOTE(review): assumes a component shaped like
        # '<prefix>_<TASK>_...' at that position -- confirm against the list
        # file; other layouts raise IndexError or KeyError.
        component = path.split('/')[5]
        return component.split('_')[1]

    return [labelNumber[taskNameOf(path)] for path in HCPDataPathList]

In [4]:
def getDataList(HCPDataListPath):
    '''
    Read the data paths listed in a text file, one path per line.

    Parameters:
        HCPDataListPath: (string) Path of the txt file listing the data paths.
    Returns:
        list of string: The whitespace-stripped, non-empty lines of the file,
                        in file order.
    '''
    with open(HCPDataListPath, 'r') as fr:
        # Skip blank lines (e.g. stray empty lines at the end of the file):
        # an empty path would otherwise crash label parsing and data loading.
        return [line.strip() for line in fr if line.strip()]

In [5]:
def getData(dataPath):
    '''
    Load one image file with nibabel and return its data array.

    Parameters:
        dataPath: (string) Full path of the image file.
    Returns:
        The image data as returned by nibabel's `get_fdata()` (a
        floating-point NumPy array).
    '''
    image = nbl.load(dataPath)
    return image.get_fdata()

In [6]:
def getNextX(HCPDataRootPath, HCPDataPathList, index):
    '''
    Load the data for the selected paths and stack them into one array.

    Parameters:
        HCPDataRootPath: (string) Prefix concatenated (with no separator
                         added) in front of every selected path.
        HCPDataPathList: (numpy array of string) All relative data paths.
        index: Indices selecting which paths to load.
    Returns:
        numpy array: The loaded volumes stacked along a new first axis.

    NOTE(review): `index` is expected to be an integer array, as passed by
    HCPDataSet.nextBatch. A scalar index would select a single string, and
    iterating it would walk its characters.
    '''
    selectedPaths = HCPDataPathList[index]
    volumes = [getData(HCPDataRootPath + relPath) for relPath in selectedPaths]
    return np.array(volumes)

In [7]:
def getNextY(HCPLabelList, index):
    '''
    Select the labels at the given position(s).

    Parameters:
        HCPLabelList: Sequence of labels (a numpy array when called from
                      HCPDataSet, which allows an index array).
        index: A single position or an array of positions.
    Returns:
        The selected label(s).
    '''
    selectedLabels = HCPLabelList[index]
    return selectedLabels

In [8]:
class HCPDataSet():
    '''
    Data paths paired with integer labels, from which random mini-batches
    of (data, label) can be drawn.
    '''
    def __init__(self, HCPDataRootPath, HCPDataPathList, HCPLabelList):
        '''
        Parameters:
            HCPDataRootPath: (string) Prefix prepended to every data path
                             when loading.
            HCPDataPathList: (list of string) Relative data paths.
            HCPLabelList: (list of int) One label per data path.
        Raises:
            ValueError: if the path and label lists differ in length.
        '''
        if len(HCPDataPathList) != len(HCPLabelList):
            # A mismatch would pair data with the wrong labels (or index out
            # of range) in nextBatch, so refuse it up front instead of
            # printing and carrying on.
            raise ValueError("dataPathList has {} entries but labelList has {}"
                             .format(len(HCPDataPathList), len(HCPLabelList)))
        self.HCPDataRootPath = HCPDataRootPath
        # Stored as numpy arrays so an integer index array can select a batch.
        self.HCPDataPathList = np.array(HCPDataPathList)
        self.HCPLabelList = np.array(HCPLabelList)
        self.total = len(HCPDataPathList)

    def nextBatch(self, batchSize):
        '''
        Draw a random batch. Indices are sampled uniformly with replacement,
        so the same sample may appear more than once in a batch.

        Parameters:
            batchSize: (int) Number of samples to draw.
        Returns:
            (nextX, nextY): The stacked data array and the matching labels.
        Raises:
            ValueError: if the dataset is empty.
        '''
        if self.total == 0:
            raise ValueError("Cannot draw a batch from an empty HCPDataSet")
        index = np.random.randint(0, self.total, batchSize)
        nextX = getNextX(self.HCPDataRootPath, self.HCPDataPathList, index)
        nextY = getNextY(self.HCPLabelList, index)
        return nextX, nextY

In [9]:
class EntireDataSet:
    '''Holds the training split and the evaluation split together.'''

    def __init__(self, trainDataSet, evalDataSet):
        '''
        Parameters:
            trainDataSet: The split used for training.
            evalDataSet: The split used for evaluation.
        '''
        self.evalDataSet = evalDataSet
        self.trainDataSet = trainDataSet

In [10]:
def getHCPDataSet(HCPDataRootPath, HCPDataListPath, rate=0.7, *,
                  shuffle=False, seed=None):
    '''
    Build training and evaluation datasets from a list file of HCP data paths.

    The first int(N * rate) paths form the training set and the remaining
    paths form the evaluation set. Paths are taken in file order, or in a
    random order when shuffle=True.

    Parameters:
        HCPDataRootPath: (string) The root path of the HCP data, prepended to
                         every listed path, for example '../../data/disk/'.
        HCPDataListPath: (string) The txt file listing one data path per line,
                         for example './label.txt'.
        rate: (float) The fraction of the entire dataset used for training;
              must be in [0, 1].
        shuffle: (bool) If True, shuffle the paths before splitting. Use this
                 when the list file is grouped (for example by task): a plain
                 positional split would then leave some labels only in one
                 of the two sets.
        seed: (int or None) Seed for the shuffle, for a reproducible split.
    Returns:
        EntireDataSet: An object whose trainDataSet and evalDataSet attributes
                       are HCPDataSet objects.
    Raises:
        ValueError: if rate is not in [0, 1].
    '''
    # A negative rate would silently slice from the end of the list, and
    # rate > 1 would silently leave the evaluation set empty.
    if not 0 <= rate <= 1:
        raise ValueError("rate must be in [0, 1], got {}".format(rate))

    dataPathList = getDataList(HCPDataListPath)
    if shuffle:
        order = np.random.RandomState(seed).permutation(len(dataPathList))
        dataPathList = [dataPathList[i] for i in order]

    totalTraining = int(len(dataPathList) * rate)
    trainDataPathList = dataPathList[:totalTraining]
    trainLabelList = getLabelList(trainDataPathList)
    evalDataPathList = dataPathList[totalTraining:]
    evalLabelList = getLabelList(evalDataPathList)

    trainDataSet = HCPDataSet(HCPDataRootPath, trainDataPathList, trainLabelList)
    evalDataSet = HCPDataSet(HCPDataRootPath, evalDataPathList, evalLabelList)
    return EntireDataSet(trainDataSet, evalDataSet)