In [1]:
import h5py
import os
import pandas as pd
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
from tqdm import tqdm

# librerie pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.optim import Adam
from torch.utils import data
import torch.utils.data as utils
from torch.utils.data import Dataset
from torchvision import transforms

from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import OneHotEncoder
label_encoder = LabelEncoder()

from scipy import stats
import torchvision.models as models


In [2]:
Comp = 'Poli'

In [3]:
class HDF5Dataset(data.Dataset):
    """Represents an abstract HDF5 dataset.
    
    Input params:
        file_path: Path to the folder containing the dataset (one or multiple HDF5 files).
            the dataset is fits into memory. Otherwise, leave this at false and 
            the data will load lazily.
        data_cache_size: Number of HDF5 files that can be cached in the cache (default=3).
    """
    def __init__(self, file_path,  data_cache_size):
        super().__init__()
        self.data_info = []
        self.data_cache = {}
        self.data_cache_size = data_cache_size

        # Search for all h5 files
        p = Path(file_path)
        assert(p.is_dir())

        files = sorted(p.glob('*.h5'))
        if len(files) < 1:
            raise RuntimeError('No hdf5 datasets found')

        for h5dataset_fp in files:
            self._add_data_infos(str(h5dataset_fp.resolve()))
            
    def __getitem__(self, index):
        # get data
        x = self.get_data("data", index)

        x = torch.from_numpy(x)

        # get label
        y = self.get_data("label", index)
        y = torch.from_numpy(y)
        return (x, y)

    def __len__(self):
        return len(self.get_data_infos('data'))
    
    def _add_data_infos(self, file_path):
        with h5py.File(file_path) as h5_file:
            # Walk through all groups, extracting datasets
            for gname, group in h5_file.items():
                for dname, ds in group.items():
                    # if data is not loaded its cache index is -1
                    idx = -1
                    
                    # type is derived from the name of the dataset; we expect the dataset
                    # name to have a name such as 'data' or 'label' to identify its type
                    # we also store the shape of the data in case we need it
                    self.data_info.append({'file_path': file_path, 'type': dname, 'shape': ds.value.shape, 'cache_idx': idx})

    def _load_data(self, file_path):
        """Load data to the cache given the file
        path and update the cache index in the
        data_info structure.
        """
        with h5py.File(file_path) as h5_file:
            for gname, group in h5_file.items():
                for dname, ds in group.items():
                    # add data to the data cache and retrieve
                    # the cache index
                    idx = self._add_to_cache(ds.value, file_path)

                    # find the beginning index of the hdf5 file we are looking for
                    file_idx = next(i for i,v in enumerate(self.data_info) if v['file_path'] == file_path)

                    # the data info should have the same index since we loaded it in the same way
                    self.data_info[file_idx + idx]['cache_idx'] = idx

        # remove an element from data cache if size was exceeded
        if len(self.data_cache) > self.data_cache_size:
            # remove one item from the cache at random
            removal_keys = list(self.data_cache)
            removal_keys.remove(file_path)
            self.data_cache.pop(removal_keys[0])
            # remove invalid cache_idx
            self.data_info = [{'file_path': di['file_path'], 'type': di['type'], 'shape': di['shape'], 'cache_idx': -1} if di['file_path'] == removal_keys[0] else di for di in self.data_info]

    def _add_to_cache(self, data, file_path):
        """Adds data to the cache and returns its index. There is one cache
        list for every file_path, containing all datasets in that file.
        """
        if file_path not in self.data_cache:
            self.data_cache[file_path] = [data]
        else:
            self.data_cache[file_path].append(data)
        return len(self.data_cache[file_path]) - 1

    def get_data_infos(self, type):
        """Get data infos belonging to a certain type of data.
        """
        data_info_type = [di for di in self.data_info if di['type'] == type]
        return data_info_type

    def get_data(self, type, i):
        """Call this function anytime you want to access a chunk of data from the
            dataset. This will make sure that the data is loaded in case it is
            not part of the data cache.
        """
        fp = self.get_data_infos(type)[i]['file_path']
        if fp not in self.data_cache:
            self._load_data(fp)
        
        # get new cache_idx assigned by _load_data_info
        cache_idx = self.get_data_infos(type)[i]['cache_idx']
        return self.data_cache[fp][cache_idx]

In [4]:
class dataset_h5(torch.utils.data.Dataset):
    
    def __init__(self, in_file, imgs_key='data', labels_key='label', transform = None):
        super(dataset_h5, self).__init__()

        self.in_file = in_file
        self.imgs_key = imgs_key
        self.labels_key = labels_key
        self.transform = transform

    def __getitem__(self, index):
        with h5py.File(self.in_file,'r') as db:
            img_data = db.file[self.imgs_key][index]
            label = db.file[self.labels_key][index]
            if self.transform:
                img_data = self.transform(img_data)
        
        return (img_data, label, index)

    def __len__(self):
        with h5py.File(self.in_file,'r') as db:
            lens=len(db[self.labels_key])
            
        return lens

In [5]:
 with h5py.File('D:/ADNI/Dati/ADNI_T1/ADNI1_T1/ADNI_Registrate/H5Corrette/data_img_norm_bilanciate.h5','r') as db:
        print(db['label'])

<HDF5 dataset "label": shape (618, 1), type "<f4">


In [6]:
if Comp == 'FabioPC':
    dataset = dataset_h5('/media/fabio/Disco locale/Fabio/Programmazione/Python/Poliambulanza/Alzheimer/Dati/ADNI/ADNI_Prova/H5Corrette/data_img.h5', transform)
else:
    dataset = dataset_h5('D:/ADNI/Dati/ADNI_T1/ADNI1_T1/ADNI_Registrate/H5Corrette/data_img_norm_intere.h5')

In [7]:
from torch.utils.data.sampler import SubsetRandomSampler

dataset_size = len(dataset)
indices = list(range(dataset_size))
validation_split = .2
shuffle_dataset = True
random_seed= 42

split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
    
train_indices, val_indices = indices[split:], indices[:split]
print(train_indices)
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

[181, 531, 364, 177, 593, 199, 421, 300, 408, 396, 286, 257, 259, 610, 42, 73, 537, 66, 11, 163, 443, 210, 332, 83, 278, 448, 79, 23, 287, 301, 280, 244, 290, 329, 137, 188, 165, 296, 9, 196, 231, 319, 604, 265, 84, 174, 547, 218, 316, 490, 390, 213, 153, 75, 92, 430, 68, 15, 192, 375, 88, 518, 117, 409, 434, 33, 0, 611, 553, 369, 355, 453, 551, 22, 472, 116, 89, 182, 495, 411, 18, 428, 535, 144, 302, 565, 557, 425, 272, 261, 362, 429, 167, 54, 441, 46, 93, 304, 108, 292, 195, 617, 513, 467, 370, 407, 7, 412, 423, 284, 275, 581, 69, 264, 432, 298, 249, 572, 274, 149, 124, 607, 530, 185, 333, 312, 477, 310, 609, 31, 586, 506, 568, 486, 141, 19, 172, 483, 482, 25, 446, 589, 605, 318, 245, 338, 154, 126, 367, 113, 173, 57, 344, 222, 17, 320, 255, 327, 591, 190, 341, 543, 291, 94, 180, 395, 354, 550, 334, 5, 45, 574, 416, 525, 16, 48, 597, 563, 3, 349, 555, 469, 388, 464, 394, 225, 26, 583, 263, 50, 229, 37, 157, 237, 592, 175, 519, 436, 194, 521, 383, 596, 527, 67, 414, 168, 500, 162, 309

In [8]:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=1, 
                                           sampler=train_sampler)

In [9]:
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=3,
                                                sampler=valid_sampler)

In [None]:
#loader_params = {'batch_size': 15, 'shuffle': False, 'num_workers': 0}

In [None]:
#data_loader = data.DataLoader(dataset, **loader_params)

In [None]:
#data_loader.dataset.get_data("data",0).shape

In [None]:
#data_loader.dataset.get_data_infos("data")

In [None]:
for batch_index, (img, labels, index) in enumerate(train_loader):
    print(labels, index, img.shape)


In [None]:
#dataset.data_cache_size

In [None]:
labels

In [None]:
fig, axs = plt.subplots(2, 2)
axs[0, 0].imshow(img[0][80,:,:])
axs[0, 1].imshow(img[1][80,:,:])
axs[1, 0].imshow(img[2][80,:,:])

In [14]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv3d(1, 5, kernel_size=(3, 3, 3))
        self.conv2 = nn.Conv3d(5, 10, kernel_size=(3, 4, 4))
        
        self.conv3 = nn.Conv3d(10, 15, kernel_size=(3, 5, 5))
        self.conv4 = nn.Conv3d(15, 20, kernel_size=(4, 4, 4))
        
        
        
        self.fc1 = nn.Linear( 8 * 13 * 13 * 20, 50) # 38 * 61 * 61 * 20
        self.fc2 = nn.Linear(50, 20)
        self.fc3 = nn.Linear(20,3)
        
    def forward(self, x):
        x = F.max_pool3d(F.relu(self.conv1(x)), (2, 2, 2))
        x = F.max_pool3d(F.relu(self.conv2(x)), (2, 2, 2))
        
        x =F.max_pool3d(F.relu(self.conv3(x)), (2, 2, 2))
        x = F.max_pool3d(F.relu(self.conv4(x)), (2, 2, 2))
            
            
        
        x = x.view(-1,13 * 13 * 20 * 8)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.softmax(x)

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
model = Net().to(device)

In [16]:
import torch.optim as optim
weights = [1, 193/225, 194/199]
class_weights = torch.FloatTensor(weights).cuda()

criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

In [17]:
for i in range(0,50):
    
    loss_list, batch_list = [], []
    total_acc = 0
    total_loss = 0
    
    for num, (img, label, index) in enumerate(train_loader):
        
        total = 0
        correct=0
        
        inputs = img.unsqueeze(1)
        print(inputs.shape)
        label = np.argmax(label, axis=1)

        output = model(inputs.float().to(torch.device("cuda" if torch.cuda.is_available() else "cpu")))
        loss = criterion(output, label.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"), dtype= torch.long))
        _, predicted = torch.max(output, 1)
        print()
        print(output)
        print(label)
        
        
        total += label.size(0)
        correct += (predicted == label.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))).sum().item()
        
        optimizer.zero_grad()
        loss_list.append(loss.detach().cpu().item())
        batch_list.append(i+1)
        loss.backward()
        optimizer.step()
        
        total_acc +=correct
        total_loss += loss.detach().cpu().item()
        if num%1==0:
            print('Train - Epoch %d, Batch: %d, Loss: %f, Acc: %d' % (i, num, loss.detach().cpu().item(), correct))
    print('ACC = ' +str(total_acc/617))
    print('Loss = ' + str(total_loss/617))

torch.Size([1, 1, 166, 256, 256])





tensor([[0.3308, 0.3458, 0.3234]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([1])
Train - Epoch 0, Batch: 0, Loss: 1.086190, Acc: 1
torch.Size([1, 1, 166, 256, 256])

tensor([[0.3295, 0.3473, 0.3232]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([0])
Train - Epoch 0, Batch: 1, Loss: 1.102450, Acc: 0
torch.Size([1, 1, 166, 256, 256])

tensor([[0.3317, 0.3503, 0.3181]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([0])
Train - Epoch 0, Batch: 2, Loss: 1.100348, Acc: 0
torch.Size([1, 1, 166, 256, 256])

tensor([[0.3344, 0.3501, 0.3155]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([2])
Train - Epoch 0, Batch: 3, Loss: 1.116569, Acc: 0
torch.Size([1, 1, 166, 256, 256])

tensor([[0.3346, 0.3495, 0.3159]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([2])
Train - Epoch 0, Batch: 4, Loss: 1.116124, Acc: 0
torch.Size([1, 1, 166, 256, 256])

tensor([[0.3368, 0.3460, 0.3173]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([1])
Train - Epoch 0, Batch: 5, L

KeyboardInterrupt: 

In [None]:
label

In [None]:
np.argmax(label, axis=1)

In [None]:
PATH = os.getcwd() + '/LeNet5_CNN.pth'
torch.save(model.state_dict(), PATH)

In [None]:
net = Net().to(device)
net.load_state_dict(torch.load(PATH))

In [None]:
dataiter = iter(validation_loader)

In [None]:
images, labels = dataiter.next()

In [None]:
asd

In [None]:
total = 0
correct=0
with torch.no_grad():
    for data in validation_loader:
        images, labels = data
        print(labels)
        outputs = net(images.unsqueeze(1).float().to(torch.device("cuda" if torch.cuda.is_available() else "cpu")))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        print(predicted.to("cpu"))
        correct += (predicted.to("cpu") == np.argmax(labels, axis=1)).sum().item()
        print(correct)

print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))