In [1]:
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import nibabel as nbl

In [2]:
# Text file that lists the data paths, one path per line.
HCPDataListPath = './s2_con1_label.txt'
# Root directory prefix under which the listed data paths live.
HCPDataRootPath = '../../data/disk/'
# Task name -> integer class label; labels are 0..6 in the order listed.
labelNumber = {
    task: index
    for index, task in enumerate(
        ('EMOTION', 'GAMBLING', 'LANGUAGE', 'MOTOR', 'RELATIONAL', 'SOCIAL', 'WM')
    )
}

In [3]:
def getLabelList(HCPDataPathList):
    """Map each data path to its integer task label.

    The task name is read from the path itself: the component at index 5
    after splitting on '/', and within that component the token at index 1
    after splitting on '_'. The name is then looked up in the module-level
    ``labelNumber`` dict.

    Parameters:
        HCPDataPathList: Iterable of path strings.
    Result:
        A numpy array with one label per input path, in input order.
    """
    return np.array([
        labelNumber[path.split('/')[5].split('_')[1]]
        for path in HCPDataPathList
    ])

In [4]:
def getDataList(HCPDataListPath):
    """Read the path list file and return its entries.

    Each line of the file is stripped of surrounding whitespace. Blank or
    whitespace-only lines (for example a trailing empty line at the end of
    the file) are skipped, so they never become empty-string entries.

    Parameters:
        HCPDataListPath: (string) Path of the text file to read.
    Result:
        A list of non-empty path strings, in file order.
    """
    with open(HCPDataListPath, 'r') as fr:
        # Iterate the file object directly instead of readlines() so the
        # raw lines are not all held in memory at once.
        stripped_lines = (line.strip() for line in fr)
        return [path for path in stripped_lines if path]

In [5]:
def getData(HCPDataRootPath, dataPath):
    """Load one CIFTI file and return its volumetric part as a dense array.

    The file at ``HCPDataRootPath + dataPath`` (plain string concatenation)
    is loaded with nibabel and its data is flattened to one dimension. For
    every brain model that is not a surface model, the model's slice of the
    flattened data is written into a zero-filled volume at the model's
    voxel coordinates. Surface models are skipped.

    Parameters:
        HCPDataRootPath: (string) Prefix prepended to ``dataPath``.
        dataPath: (string) File path appended to the root prefix.
    Result:
        A float32 numpy array of shape (1, 91, 109, 91).
    """
    image = nbl.load(HCPDataRootPath + dataPath)
    flat_values = image.get_fdata().reshape((-1))
    # Index map 1 provides both the volume dimensions and the brain models.
    index_map = image.header.matrix.get_index_map(1)
    volume = np.zeros(index_map.volume.volume_dimensions)
    for model in index_map.brain_models:
        if model.model_type != 'CIFTI_MODEL_TYPE_SURFACE':
            # One (i, j, k) row per voxel -> tuple of three index arrays.
            ijk = tuple(np.transpose(model.voxel_indices_ijk))
            start = model.index_offset
            stop = start + model.index_count
            volume[ijk] = flat_values[start:stop]
    # Add a leading axis of length 1 and convert to single precision.
    return volume.reshape((1, 91, 109, 91)).astype(np.float32)

In [6]:
class EntireDataSet:
    """Plain container bundling the training, evaluation and test splits."""

    def __init__(self, trainDataSet, evalDataSet, testDataSet):
        # Keep the three splits unchanged under attributes of the same names.
        (self.trainDataSet,
         self.evalDataSet,
         self.testDataSet) = (trainDataSet, evalDataSet, testDataSet)

In [7]:
class HCPDataSet(Dataset):
    """Dataset that loads one data file from disk per requested item.

    Each item is a ``(data, label)`` pair of torch tensors: ``data`` wraps
    the array returned by ``getData`` for the item's path, and ``label``
    wraps the entry of the label list at the same index.
    """

    def __init__(self, HCPDataRootPath, HCPDataPathList, HCPLabelList):
        """Store the root path, the data paths and their labels.

        Parameters:
            HCPDataRootPath: (string) Root path handed to ``getData``.
            HCPDataPathList: Sequence of data paths.
            HCPLabelList: Sequence of labels, one per data path.
        Raises:
            ValueError: If the two sequences differ in length.
        """
        # A mismatch would otherwise surface much later as an IndexError or
        # as silently misaligned labels, so fail at construction time.
        if len(HCPDataPathList) != len(HCPLabelList):
            raise ValueError(
                "dataPathList and labelList must have the same length, "
                "got dataPathList {}, labelList {}"
                .format(len(HCPDataPathList), len(HCPLabelList)))
        self.HCPDataRootPath = HCPDataRootPath
        self.HCPDataPathList = np.array(HCPDataPathList)
        self.HCPLabelList = np.array(HCPLabelList)
        self.total = len(HCPDataPathList)

    def __getitem__(self, index):
        """Load the item at ``index`` and return ``(data, label)`` tensors."""
        dataPath = self.HCPDataPathList[index]
        data = getData(self.HCPDataRootPath, dataPath)
        # Wrap the single label in an array so torch.from_numpy accepts it.
        label = np.array(self.HCPLabelList[index])
        return torch.from_numpy(data), torch.from_numpy(label)

    def __len__(self):
        """Return the number of data paths given at construction."""
        return self.total

In [8]:
def getHCPDataSet(HCPDataRootPath, HCPDataListPath, evalRate=0.2, testRate=0.2, tiny_data=0):
    '''
    Build the training, evaluation and test datasets from a path list file.

    The paths are split in file order, without shuffling: the first
    int(N * (1 - evalRate - testRate)) paths form the training set, the
    next int(N * evalRate) paths form the evaluation set, and all remaining
    paths form the test set. The three sizes are printed.

    Parameters:
        HCPDataRootPath: (string) The root path of HCP data,
                         for example '../../data/disk/'.
        HCPDataListPath: (string) The txt file of the paths,
                         for example './label.txt'.
        evalRate: (double) The fraction of the entire dataset used for
                  evaluation.
        testRate: (double) The fraction of the entire dataset used for
                  testing.
        tiny_data: (int) Use data[:tiny_data] if tiny_data != 0.
    Result:
        EntireDataSet: An object whose trainDataSet, evalDataSet and
                       testDataSet attributes are HCPDataSet objects.
    Raises:
        ValueError: If a rate is negative or the two rates sum to more
                    than 1.
    '''
    # With a negative training fraction the slices below would overlap and
    # leak test samples into training, so reject such rates up front. The
    # small tolerance absorbs floating-point noise in sums meant to be 1.
    if evalRate < 0 or testRate < 0 or evalRate + testRate > 1 + 1e-9:
        raise ValueError(
            'evalRate and testRate must be non-negative and sum to at most 1, '
            'got evalRate {}, testRate {}'.format(evalRate, testRate))

    dataPathList = getDataList(HCPDataListPath)
    if tiny_data != 0:
        dataPathList = dataPathList[:int(tiny_data)]

    totalNumber = len(dataPathList)
    totalTraining = int(totalNumber * (1 - evalRate - testRate))
    totalEvaluation = int(totalNumber * evalRate)
    # The test split takes whatever is left after truncating the two counts.
    evalEnd = totalTraining + totalEvaluation
    print('Training: {}, Evaluation: {}, Test: {}'
          .format(totalTraining, totalEvaluation, totalNumber - evalEnd))

    trainDataPathList = dataPathList[:totalTraining]
    evalDataPathList = dataPathList[totalTraining:evalEnd]
    testDataPathList = dataPathList[evalEnd:]

    trainDataSet = HCPDataSet(HCPDataRootPath, trainDataPathList,
                              getLabelList(trainDataPathList))
    evalDataSet = HCPDataSet(HCPDataRootPath, evalDataPathList,
                             getLabelList(evalDataPathList))
    testDataSet = HCPDataSet(HCPDataRootPath, testDataPathList,
                             getLabelList(testDataPathList))

    return EntireDataSet(trainDataSet, evalDataSet, testDataSet)