In [1]:
import h5py
import numpy as np
from pathlib import Path
import torch
from torch.utils import data
import matplotlib.pyplot as plt
import SimpleITK as sitk
import os
import torch
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.optim import Adam
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
#from torchsummary import summary
import torch.utils.data as utils
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
label_encoder = LabelEncoder()
from torch.utils.data import Dataset
from tqdm import tqdm
import h5py

In [2]:
class HDF5Dataset(data.Dataset):
    """Represents an abstract HDF5 dataset.
    
    Input params:
        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
            the dataset is fits into memory. Otherwise, leave this at false and 
            the data will load lazily.
        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
    """
    def __init__(self, file_path,  data_cache_size):
        super().__init__()
        self.data_info = []
        self.data_cache = {}
        self.data_cache_size = data_cache_size

        # Search for all h5 files
        p = Path(file_path)
        assert(p.is_dir())

        files = sorted(p.glob('*.h5'))
        if len(files) < 1:
            raise RuntimeError('No hdf5 datasets found')

        for h5dataset_fp in files:
            self._add_data_infos(str(h5dataset_fp.resolve()))
            
    def __getitem__(self, index):
        # get data
        x = self.get_data("data", index)

        x = torch.from_numpy(x)

        # get label
        y = self.get_data("label", index)
        y = torch.from_numpy(y)
        return (x, y)

    def __len__(self):
        return len(self.get_data_infos('data'))
    
    def _add_data_infos(self, file_path):
        with h5py.File(file_path) as h5_file:
            # Walk through all groups, extracting datasets
            for gname, group in h5_file.items():
                for dname, ds in group.items():
                    # if data is not loaded its cache index is -1
                    idx = -1
                    
                    # type is derived from the name of the dataset; we expect the dataset
                    # name to have a name such as 'data' or 'label' to identify its type
                    # we also store the shape of the data in case we need it
                    self.data_info.append({'file_path': file_path, 'type': dname, 'shape': ds.value.shape, 'cache_idx': idx})

    def _load_data(self, file_path):
        """Load data to the cache given the file
        path and update the cache index in the
        data_info structure.
        """
        with h5py.File(file_path) as h5_file:
            for gname, group in h5_file.items():
                for dname, ds in group.items():
                    # add data to the data cache and retrieve
                    # the cache index
                    idx = self._add_to_cache(ds.value, file_path)

                    # find the beginning index of the hdf5 file we are looking for
                    file_idx = next(i for i,v in enumerate(self.data_info) if v['file_path'] == file_path)

                    # the data info should have the same index since we loaded it in the same way
                    self.data_info[file_idx + idx]['cache_idx'] = idx

        # remove an element from data cache if size was exceeded
        if len(self.data_cache) > self.data_cache_size:
            # remove one item from the cache at random
            removal_keys = list(self.data_cache)
            removal_keys.remove(file_path)
            self.data_cache.pop(removal_keys[0])
            # remove invalid cache_idx
            self.data_info = [{'file_path': di['file_path'], 'type': di['type'], 'shape': di['shape'], 'cache_idx': -1} if di['file_path'] == removal_keys[0] else di for di in self.data_info]

    def _add_to_cache(self, data, file_path):
        """Adds data to the cache and returns its index. There is one cache
        list for every file_path, containing all datasets in that file.
        """
        if file_path not in self.data_cache:
            self.data_cache[file_path] = [data]
        else:
            self.data_cache[file_path].append(data)
        return len(self.data_cache[file_path]) - 1

    def get_data_infos(self, type):
        """Get data infos belonging to a certain type of data.
        """
        data_info_type = [di for di in self.data_info if di['type'] == type]
        return data_info_type

    def get_data(self, type, i):
        """Call this function anytime you want to access a chunk of data from the
            dataset. This will make sure that the data is loaded in case it is
            not part of the data cache.
        """
        fp = self.get_data_infos(type)[i]['file_path']
        if fp not in self.data_cache:
            self._load_data(fp)
        
        # get new cache_idx assigned by _load_data_info
        cache_idx = self.get_data_infos(type)[i]['cache_idx']
        return self.data_cache[fp][cache_idx]

In [3]:
dataset = HDF5Dataset('D:/ADNI/Dati/ADNI_T1/ADNI1_T1/ADNI_Registrate/H5', data_cache_size=40)




In [4]:
loader_params = {'batch_size': 1, 'shuffle': False, 'num_workers': 0}

In [5]:
data_loader = data.DataLoader(dataset, **loader_params)

In [None]:
#data_loader.dataset.get_data("data",0).shape

In [None]:
data_loader.dataset.get_data_infos("data")

In [None]:
dataset.__len__()

In [None]:
for a, b in data_loader:
    print(a.shape)

In [None]:
dataset.data_cache_size

In [None]:
data_loader.dataset.get_data("data", 1)

In [None]:
plt.imshow(data_loader.dataset.get_data("data", 28)[5][:,160,:])

In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv3d(1, 10, kernel_size=(5, 5, 5))
        self.conv2 = nn.Conv3d(10, 20, kernel_size=(5, 5, 5))
        self.conv2_drop = nn.Dropout3d()
        self.fc1 = nn.Linear(100, 50)
        self.fc2 = nn.Linear(50,3)
        
    def forward(self, x):
        x = F.relu(F.max_pool3d(self.conv1(x), (2, 2, 2)))
        x = F.relu(F.max_pool3d(self.conv2_drop(self.conv2(x)), (2, 2, 2)))
        x = x.view(-1, 100)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
model = Net().to(device)

In [8]:
criterion = nn.MSELoss()
optimizer = Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)

In [9]:
for i in range(0,1):
    loss_list, batch_list = [], []

    for num, (img, label) in enumerate(data_loader):
        print()
        inputs = img.squeeze(0).unsqueeze(1)

        b = torch.split(inputs, 10, dim=0)
        label_new  = torch.split(label.squeeze(0), 10, dim=0)
        print(label_new[0].shape)
        for i,a in enumerate(b):
            print(a.size())


            output = model(a.float().to(torch.device("cuda" if torch.cuda.is_available() else "cpu")))
            print(output.shape)
            loss = criterion(output, label_new[i].float().to(torch.device("cuda" if torch.cuda.is_available() else "cpu")))
            optimizer.zero_grad()
            loss_list.append(loss.detach().cpu().item())
            batch_list.append(i+1)
            loss.backward()
            optimizer.step()
            print('Train - Epoch %d, Batch: %d, Loss: %f' % (i, num, loss.detach().cpu().item()))


torch.Size([10, 3])
torch.Size([10, 1, 166, 256, 256])
torch.Size([282796, 3])


  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: The size of tensor a (282796) must match the size of tensor b (10) at non-singleton dimension 0

In [None]:
torch.cuda.empty_cache()